Skip to main content

timer_lib/
registry.rs

1use std::collections::HashMap;
2use std::sync::{
3    atomic::{AtomicU64, Ordering},
4    Arc,
5};
6use std::time::Duration;
7
8use tokio::sync::RwLock;
9
10use crate::errors::TimerError;
11use crate::timer::{Timer, TimerCallback, TimerOutcome, TimerState};
12
13/// A registry for tracking timers by identifier.
14#[derive(Clone, Default)]
15pub struct TimerRegistry {
16    timers: Arc<RwLock<HashMap<u64, Timer>>>,
17    next_id: Arc<AtomicU64>,
18}
19
20impl TimerRegistry {
21    /// Creates a new timer registry.
22    pub fn new() -> Self {
23        Self {
24            timers: Arc::new(RwLock::new(HashMap::new())),
25            next_id: Arc::new(AtomicU64::new(0)),
26        }
27    }
28
29    /// Inserts an existing timer and returns its identifier.
30    pub async fn insert(&self, timer: Timer) -> u64 {
31        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
32        self.timers.write().await.insert(id, timer);
33        id
34    }
35
36    /// Starts and registers a one-time timer.
37    pub async fn start_once<F>(
38        &self,
39        delay: Duration,
40        callback: F,
41    ) -> Result<(u64, Timer), TimerError>
42    where
43        F: TimerCallback + 'static,
44    {
45        let timer = Timer::new();
46        let _ = timer.start_once(delay, callback).await?;
47        let id = self.insert(timer.clone()).await;
48        Ok((id, timer))
49    }
50
51    /// Starts and registers a recurring timer.
52    pub async fn start_recurring<F>(
53        &self,
54        interval: Duration,
55        callback: F,
56        expiration_count: Option<usize>,
57    ) -> Result<(u64, Timer), TimerError>
58    where
59        F: TimerCallback + 'static,
60    {
61        let timer = Timer::new();
62        let _ = timer
63            .start_recurring(interval, callback, expiration_count)
64            .await?;
65        let id = self.insert(timer.clone()).await;
66        Ok((id, timer))
67    }
68
69    /// Removes a timer from the registry and returns it.
70    pub async fn remove(&self, id: u64) -> Option<Timer> {
71        self.timers.write().await.remove(&id)
72    }
73
74    /// Returns true when the registry tracks the given timer identifier.
75    pub async fn contains(&self, id: u64) -> bool {
76        self.timers.read().await.contains_key(&id)
77    }
78
79    /// Stops a timer by identifier when it exists.
80    pub async fn stop(&self, id: u64) -> Result<Option<TimerOutcome>, TimerError> {
81        let timer = self.get(id).await;
82        match timer {
83            Some(timer) => timer.stop().await.map(Some),
84            None => Ok(None),
85        }
86    }
87
88    /// Cancels a timer by identifier when it exists.
89    pub async fn cancel(&self, id: u64) -> Result<Option<TimerOutcome>, TimerError> {
90        let timer = self.get(id).await;
91        match timer {
92            Some(timer) => timer.cancel().await.map(Some),
93            None => Ok(None),
94        }
95    }
96
97    /// Pauses a timer by identifier when it exists.
98    pub async fn pause(&self, id: u64) -> Result<bool, TimerError> {
99        let timer = self.get(id).await;
100        match timer {
101            Some(timer) => {
102                timer.pause().await?;
103                Ok(true)
104            }
105            None => Ok(false),
106        }
107    }
108
109    /// Resumes a timer by identifier when it exists.
110    pub async fn resume(&self, id: u64) -> Result<bool, TimerError> {
111        let timer = self.get(id).await;
112        match timer {
113            Some(timer) => {
114                timer.resume().await?;
115                Ok(true)
116            }
117            None => Ok(false),
118        }
119    }
120
121    /// Stops all timers currently tracked by the registry.
122    pub async fn stop_all(&self) {
123        let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
124        for timer in timers {
125            let _ = timer.stop().await;
126        }
127    }
128
129    /// Pauses all running timers currently tracked by the registry.
130    pub async fn pause_all(&self) {
131        let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
132        for timer in timers {
133            let _ = timer.pause().await;
134        }
135    }
136
137    /// Waits for all tracked timers that have a joinable outcome.
138    pub async fn join_all(&self) -> Vec<(u64, TimerOutcome)> {
139        let timers: Vec<(u64, Timer)> = self
140            .timers
141            .read()
142            .await
143            .iter()
144            .map(|(id, timer)| (*id, timer.clone()))
145            .collect();
146
147        let mut outcomes = Vec::with_capacity(timers.len());
148        for (id, timer) in timers {
149            if let Ok(outcome) = timer.join().await {
150                outcomes.push((id, outcome));
151            }
152        }
153
154        outcomes
155    }
156
157    /// Cancels all timers currently tracked by the registry.
158    pub async fn cancel_all(&self) {
159        let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
160        for timer in timers {
161            let _ = timer.cancel().await;
162        }
163    }
164
165    /// Resumes all paused timers currently tracked by the registry.
166    pub async fn resume_all(&self) {
167        let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
168        for timer in timers {
169            let _ = timer.resume().await;
170        }
171    }
172
173    /// Lists all active timers.
174    pub async fn active_ids(&self) -> Vec<u64> {
175        let timers: Vec<(u64, Timer)> = self
176            .timers
177            .read()
178            .await
179            .iter()
180            .map(|(id, timer)| (*id, timer.clone()))
181            .collect();
182
183        let mut active = Vec::new();
184        for (id, timer) in timers {
185            if timer.get_state().await != TimerState::Stopped {
186                active.push(id);
187            }
188        }
189        active
190    }
191
192    /// Retrieves a timer by ID.
193    pub async fn get(&self, id: u64) -> Option<Timer> {
194        self.timers.read().await.get(&id).cloned()
195    }
196
197    /// Returns the number of tracked timers.
198    pub async fn len(&self) -> usize {
199        self.timers.read().await.len()
200    }
201
202    /// Returns true when the registry is empty.
203    pub async fn is_empty(&self) -> bool {
204        self.len().await == 0
205    }
206
207    /// Removes all tracked timers and returns the number removed.
208    pub async fn clear(&self) -> usize {
209        let mut timers = self.timers.write().await;
210        let removed = timers.len();
211        timers.clear();
212        removed
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::timer::TimerFinishReason;
    use tokio::task::yield_now;
    use tokio::time::advance;

    /// Yields to the runtime five times so spawned timer tasks can react to a
    /// clock advance before the test inspects their state.
    async fn let_tasks_run() {
        let mut remaining = 5;
        while remaining > 0 {
            yield_now().await;
            remaining -= 1;
        }
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_start_helpers_are_easy_to_use() {
        let registry = TimerRegistry::new();

        let (first_id, one_shot) = registry
            .start_once(Duration::from_secs(1), || async { Ok(()) })
            .await
            .unwrap();
        let (second_id, repeating) = registry
            .start_recurring(Duration::from_secs(2), || async { Ok(()) }, None)
            .await
            .unwrap();

        // Each registration must receive its own identifier.
        assert_ne!(first_id, second_id);
        assert_eq!(registry.len().await, 2);
        assert!(registry.get(first_id).await.is_some());

        // Move the clock far enough for the one-shot timer to fire.
        advance(Duration::from_secs(1)).await;
        let_tasks_run().await;
        let finished = one_shot.join().await.unwrap();
        assert_eq!(finished.reason, TimerFinishReason::Completed);

        assert!(registry.active_ids().await.contains(&second_id));

        let _ = repeating.cancel().await.unwrap();
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_supports_direct_timer_controls() {
        let registry = TimerRegistry::new();
        let (id, _handle) = registry
            .start_once(Duration::from_secs(5), || async { Ok(()) })
            .await
            .unwrap();

        assert!(registry.contains(id).await);

        let cancelled = registry.cancel(id).await.unwrap().unwrap();
        assert_eq!(cancelled.reason, TimerFinishReason::Cancelled);

        // Cancelling does not untrack the timer, so `clear` still removes one entry.
        assert_eq!(registry.clear().await, 1);
        assert!(registry.is_empty().await);
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_can_pause_and_resume_tracked_timers() {
        let registry = TimerRegistry::new();
        let (id, handle) = registry
            .start_recurring(Duration::from_secs(2), || async { Ok(()) }, Some(1))
            .await
            .unwrap();
        let_tasks_run().await;

        assert!(registry.pause(id).await.unwrap());
        assert_eq!(handle.get_state().await, TimerState::Paused);

        // While paused, moving the clock past the interval must not run the callback.
        advance(Duration::from_secs(5)).await;
        let_tasks_run().await;
        assert_eq!(handle.get_statistics().await.execution_count, 0);

        assert!(registry.resume(id).await.unwrap());
        advance(Duration::from_secs(2)).await;
        let_tasks_run().await;
        let outcome = handle.join().await.unwrap();
        assert_eq!(outcome.reason, TimerFinishReason::Completed);
    }
}