1use std::collections::HashMap;
2use std::sync::{
3 atomic::{AtomicU64, Ordering},
4 Arc,
5};
6use std::time::Duration;
7
8use tokio::sync::RwLock;
9
10use crate::errors::TimerError;
11use crate::timer::{Timer, TimerCallback, TimerOutcome, TimerState};
12
13#[derive(Clone, Default)]
15pub struct TimerRegistry {
16 timers: Arc<RwLock<HashMap<u64, Timer>>>,
17 next_id: Arc<AtomicU64>,
18}
19
20impl TimerRegistry {
21 pub fn new() -> Self {
23 Self {
24 timers: Arc::new(RwLock::new(HashMap::new())),
25 next_id: Arc::new(AtomicU64::new(0)),
26 }
27 }
28
29 pub async fn insert(&self, timer: Timer) -> u64 {
31 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
32 self.timers.write().await.insert(id, timer);
33 id
34 }
35
36 pub async fn start_once<F>(
38 &self,
39 delay: Duration,
40 callback: F,
41 ) -> Result<(u64, Timer), TimerError>
42 where
43 F: TimerCallback + 'static,
44 {
45 let timer = Timer::new();
46 let _ = timer.start_once(delay, callback).await?;
47 let id = self.insert(timer.clone()).await;
48 Ok((id, timer))
49 }
50
51 pub async fn start_recurring<F>(
53 &self,
54 interval: Duration,
55 callback: F,
56 expiration_count: Option<usize>,
57 ) -> Result<(u64, Timer), TimerError>
58 where
59 F: TimerCallback + 'static,
60 {
61 let timer = Timer::new();
62 let _ = timer
63 .start_recurring(interval, callback, expiration_count)
64 .await?;
65 let id = self.insert(timer.clone()).await;
66 Ok((id, timer))
67 }
68
69 pub async fn remove(&self, id: u64) -> Option<Timer> {
71 self.timers.write().await.remove(&id)
72 }
73
74 pub async fn contains(&self, id: u64) -> bool {
76 self.timers.read().await.contains_key(&id)
77 }
78
79 pub async fn stop(&self, id: u64) -> Result<Option<TimerOutcome>, TimerError> {
81 let timer = self.get(id).await;
82 match timer {
83 Some(timer) => timer.stop().await.map(Some),
84 None => Ok(None),
85 }
86 }
87
88 pub async fn cancel(&self, id: u64) -> Result<Option<TimerOutcome>, TimerError> {
90 let timer = self.get(id).await;
91 match timer {
92 Some(timer) => timer.cancel().await.map(Some),
93 None => Ok(None),
94 }
95 }
96
97 pub async fn pause(&self, id: u64) -> Result<bool, TimerError> {
99 let timer = self.get(id).await;
100 match timer {
101 Some(timer) => {
102 timer.pause().await?;
103 Ok(true)
104 }
105 None => Ok(false),
106 }
107 }
108
109 pub async fn resume(&self, id: u64) -> Result<bool, TimerError> {
111 let timer = self.get(id).await;
112 match timer {
113 Some(timer) => {
114 timer.resume().await?;
115 Ok(true)
116 }
117 None => Ok(false),
118 }
119 }
120
121 pub async fn stop_all(&self) {
123 let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
124 for timer in timers {
125 let _ = timer.stop().await;
126 }
127 }
128
129 pub async fn pause_all(&self) {
131 let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
132 for timer in timers {
133 let _ = timer.pause().await;
134 }
135 }
136
137 pub async fn join_all(&self) -> Vec<(u64, TimerOutcome)> {
139 let timers: Vec<(u64, Timer)> = self
140 .timers
141 .read()
142 .await
143 .iter()
144 .map(|(id, timer)| (*id, timer.clone()))
145 .collect();
146
147 let mut outcomes = Vec::with_capacity(timers.len());
148 for (id, timer) in timers {
149 if let Ok(outcome) = timer.join().await {
150 outcomes.push((id, outcome));
151 }
152 }
153
154 outcomes
155 }
156
157 pub async fn cancel_all(&self) {
159 let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
160 for timer in timers {
161 let _ = timer.cancel().await;
162 }
163 }
164
165 pub async fn resume_all(&self) {
167 let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
168 for timer in timers {
169 let _ = timer.resume().await;
170 }
171 }
172
173 pub async fn active_ids(&self) -> Vec<u64> {
175 let timers: Vec<(u64, Timer)> = self
176 .timers
177 .read()
178 .await
179 .iter()
180 .map(|(id, timer)| (*id, timer.clone()))
181 .collect();
182
183 let mut active = Vec::new();
184 for (id, timer) in timers {
185 if timer.get_state().await != TimerState::Stopped {
186 active.push(id);
187 }
188 }
189 active
190 }
191
192 pub async fn get(&self, id: u64) -> Option<Timer> {
194 self.timers.read().await.get(&id).cloned()
195 }
196
197 pub async fn len(&self) -> usize {
199 self.timers.read().await.len()
200 }
201
202 pub async fn is_empty(&self) -> bool {
204 self.len().await == 0
205 }
206
207 pub async fn clear(&self) -> usize {
209 let mut timers = self.timers.write().await;
210 let removed = timers.len();
211 timers.clear();
212 removed
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::timer::TimerFinishReason;
    use tokio::task::yield_now;
    use tokio::time::advance;

    /// Yields to the current-thread runtime a few times so that other tasks
    /// get a chance to run before the test inspects timer state.
    async fn settle() {
        for _ in 0..5 {
            yield_now().await;
        }
    }

    /// The `start_*` helpers register timers under distinct ids and return
    /// handles that can be joined and cancelled directly.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_start_helpers_are_easy_to_use() {
        let registry = TimerRegistry::new();
        let (once_id, once_timer) = registry
            .start_once(Duration::from_secs(1), || async { Ok(()) })
            .await
            .unwrap();
        let (recurring_id, recurring_timer) = registry
            .start_recurring(Duration::from_secs(2), || async { Ok(()) }, None)
            .await
            .unwrap();

        assert_ne!(once_id, recurring_id);
        assert_eq!(registry.len().await, 2);
        assert!(registry.get(once_id).await.is_some());

        advance(Duration::from_secs(1)).await;
        settle().await;
        assert_eq!(
            once_timer.join().await.unwrap().reason,
            TimerFinishReason::Completed
        );

        let active = registry.active_ids().await;
        assert!(active.contains(&recurring_id));

        let _ = recurring_timer.cancel().await.unwrap();
    }

    /// Per-id controls act on the registered timer, and cancelling does not
    /// unregister it (so `clear` still removes one entry).
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_supports_direct_timer_controls() {
        let registry = TimerRegistry::new();
        let (timer_id, _timer) = registry
            .start_once(Duration::from_secs(5), || async { Ok(()) })
            .await
            .unwrap();

        assert!(registry.contains(timer_id).await);
        let outcome = registry.cancel(timer_id).await.unwrap().unwrap();
        assert_eq!(outcome.reason, TimerFinishReason::Cancelled);
        assert_eq!(registry.clear().await, 1);
        assert!(registry.is_empty().await);
    }

    /// Pausing through the registry suppresses executions until the timer is
    /// resumed, after which it runs to completion.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_can_pause_and_resume_tracked_timers() {
        let registry = TimerRegistry::new();
        let (timer_id, timer) = registry
            .start_recurring(Duration::from_secs(2), || async { Ok(()) }, Some(1))
            .await
            .unwrap();
        settle().await;

        assert!(registry.pause(timer_id).await.unwrap());
        assert_eq!(timer.get_state().await, TimerState::Paused);

        advance(Duration::from_secs(5)).await;
        settle().await;
        assert_eq!(timer.get_statistics().await.execution_count, 0);

        assert!(registry.resume(timer_id).await.unwrap());
        advance(Duration::from_secs(2)).await;
        settle().await;
        assert_eq!(
            timer.join().await.unwrap().reason,
            TimerFinishReason::Completed
        );
    }

    /// Every per-id operation reports "not found" instead of failing when the
    /// id was never registered.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_reports_unknown_ids() {
        let registry = TimerRegistry::new();
        let missing = 42;

        assert!(!registry.contains(missing).await);
        assert!(registry.get(missing).await.is_none());
        assert!(registry.stop(missing).await.unwrap().is_none());
        assert!(registry.cancel(missing).await.unwrap().is_none());
        assert!(!registry.pause(missing).await.unwrap());
        assert!(!registry.resume(missing).await.unwrap());
        assert!(registry.remove(missing).await.is_none());
        assert!(registry.is_empty().await);
    }

    /// Clones share the same map and id counter, so ids never collide across
    /// clones and changes through one clone are visible through the other.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_clones_share_timers_and_ids() {
        let registry = TimerRegistry::new();
        let handle = registry.clone();

        let first = registry.insert(Timer::new()).await;
        let second = handle.insert(Timer::new()).await;

        assert_ne!(first, second);
        assert!(registry.contains(second).await);
        assert!(handle.contains(first).await);
        assert_eq!(registry.len().await, 2);
        assert_eq!(handle.clear().await, 2);
        assert!(registry.is_empty().await);
    }
}