1use std::collections::HashMap;
2use std::sync::{
3 atomic::{AtomicU64, Ordering},
4 Arc,
5};
6use std::time::Duration;
7
8use tokio::sync::RwLock;
9
10use crate::errors::TimerError;
11use crate::timer::{RecurringSchedule, Timer, TimerCallback, TimerOutcome, TimerState};
12
13#[derive(Clone, Default)]
15pub struct TimerRegistry {
16 timers: Arc<RwLock<HashMap<u64, Timer>>>,
17 next_id: Arc<AtomicU64>,
18}
19
20impl TimerRegistry {
21 pub fn new() -> Self {
23 Self {
24 timers: Arc::new(RwLock::new(HashMap::new())),
25 next_id: Arc::new(AtomicU64::new(0)),
26 }
27 }
28
29 pub async fn insert(&self, timer: Timer) -> u64 {
31 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
32 self.timers.write().await.insert(id, timer);
33 id
34 }
35
36 pub async fn start_once<F>(
38 &self,
39 delay: Duration,
40 callback: F,
41 ) -> Result<(u64, Timer), TimerError>
42 where
43 F: TimerCallback + 'static,
44 {
45 let timer = Timer::new();
46 let _ = timer.start_once(delay, callback).await?;
47 let id = self.insert(timer.clone()).await;
48 Ok((id, timer))
49 }
50
51 pub async fn start_recurring<F>(
53 &self,
54 schedule: RecurringSchedule,
55 callback: F,
56 ) -> Result<(u64, Timer), TimerError>
57 where
58 F: TimerCallback + 'static,
59 {
60 let timer = Timer::new();
61 let _ = timer.start_recurring(schedule, callback).await?;
62 let id = self.insert(timer.clone()).await;
63 Ok((id, timer))
64 }
65
66 pub async fn remove(&self, id: u64) -> Option<Timer> {
68 self.timers.write().await.remove(&id)
69 }
70
71 pub async fn contains(&self, id: u64) -> bool {
73 self.timers.read().await.contains_key(&id)
74 }
75
76 pub async fn stop(&self, id: u64) -> Result<Option<TimerOutcome>, TimerError> {
78 let timer = self.get(id).await;
79 match timer {
80 Some(timer) => timer.stop().await.map(Some),
81 None => Ok(None),
82 }
83 }
84
85 pub async fn cancel(&self, id: u64) -> Result<Option<TimerOutcome>, TimerError> {
87 let timer = self.get(id).await;
88 match timer {
89 Some(timer) => timer.cancel().await.map(Some),
90 None => Ok(None),
91 }
92 }
93
94 pub async fn pause(&self, id: u64) -> Result<bool, TimerError> {
96 let timer = self.get(id).await;
97 match timer {
98 Some(timer) => {
99 timer.pause().await?;
100 Ok(true)
101 }
102 None => Ok(false),
103 }
104 }
105
106 pub async fn resume(&self, id: u64) -> Result<bool, TimerError> {
108 let timer = self.get(id).await;
109 match timer {
110 Some(timer) => {
111 timer.resume().await?;
112 Ok(true)
113 }
114 None => Ok(false),
115 }
116 }
117
118 pub async fn stop_all(&self) {
120 let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
121 for timer in timers {
122 let _ = timer.stop().await;
123 }
124 }
125
126 pub async fn pause_all(&self) {
128 let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
129 for timer in timers {
130 let _ = timer.pause().await;
131 }
132 }
133
134 pub async fn join_all(&self) -> Vec<(u64, TimerOutcome)> {
136 let timers: Vec<(u64, Timer)> = self
137 .timers
138 .read()
139 .await
140 .iter()
141 .map(|(id, timer)| (*id, timer.clone()))
142 .collect();
143
144 let mut outcomes = Vec::with_capacity(timers.len());
145 for (id, timer) in timers {
146 if let Ok(outcome) = timer.join().await {
147 outcomes.push((id, outcome));
148 }
149 }
150
151 outcomes
152 }
153
154 pub async fn cancel_all(&self) {
156 let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
157 for timer in timers {
158 let _ = timer.cancel().await;
159 }
160 }
161
162 pub async fn resume_all(&self) {
164 let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
165 for timer in timers {
166 let _ = timer.resume().await;
167 }
168 }
169
170 pub async fn active_ids(&self) -> Vec<u64> {
172 let timers: Vec<(u64, Timer)> = self
173 .timers
174 .read()
175 .await
176 .iter()
177 .map(|(id, timer)| (*id, timer.clone()))
178 .collect();
179
180 let mut active = Vec::new();
181 for (id, timer) in timers {
182 if timer.get_state().await != TimerState::Stopped {
183 active.push(id);
184 }
185 }
186 active
187 }
188
189 pub async fn get(&self, id: u64) -> Option<Timer> {
191 self.timers.read().await.get(&id).cloned()
192 }
193
194 pub async fn len(&self) -> usize {
196 self.timers.read().await.len()
197 }
198
199 pub async fn is_empty(&self) -> bool {
201 self.len().await == 0
202 }
203
204 pub async fn clear(&self) -> usize {
206 let mut timers = self.timers.write().await;
207 let removed = timers.len();
208 timers.clear();
209 removed
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::timer::TimerFinishReason;
    use tokio::task::yield_now;
    use tokio::time::advance;

    /// Yields to the current-thread scheduler several times so other tasks
    /// (presumably the timers' background work) can run before the test asserts.
    async fn settle() {
        for _ in 0..5 {
            yield_now().await;
        }
    }

    // Both start helpers register timers under distinct ids and hand back usable handles.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_start_helpers_are_easy_to_use() {
        let registry = TimerRegistry::new();
        let (once_id, once_timer) = registry
            .start_once(Duration::from_secs(1), || async { Ok(()) })
            .await
            .unwrap();
        let (recurring_id, recurring_timer) = registry
            .start_recurring(RecurringSchedule::new(Duration::from_secs(2)), || async {
                Ok(())
            })
            .await
            .unwrap();

        assert_ne!(once_id, recurring_id);
        assert_eq!(registry.len().await, 2);
        assert!(registry.get(once_id).await.is_some());

        advance(Duration::from_secs(1)).await;
        settle().await;
        assert_eq!(
            once_timer.join().await.unwrap().reason,
            TimerFinishReason::Completed
        );

        let active = registry.active_ids().await;
        assert!(active.contains(&recurring_id));

        let _ = recurring_timer.cancel().await.unwrap();
    }

    // Id-based controls reach the tracked timer; cancelling does not untrack it.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_supports_direct_timer_controls() {
        let registry = TimerRegistry::new();
        let (timer_id, _timer) = registry
            .start_once(Duration::from_secs(5), || async { Ok(()) })
            .await
            .unwrap();

        assert!(registry.contains(timer_id).await);
        let outcome = registry.cancel(timer_id).await.unwrap().unwrap();
        assert_eq!(outcome.reason, TimerFinishReason::Cancelled);
        assert_eq!(registry.clear().await, 1);
        assert!(registry.is_empty().await);
    }

    // A paused timer does not fire; after resuming it completes on schedule.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_can_pause_and_resume_tracked_timers() {
        let registry = TimerRegistry::new();
        let (timer_id, timer) = registry
            .start_recurring(
                RecurringSchedule::new(Duration::from_secs(2)).with_expiration_count(1),
                || async { Ok(()) },
            )
            .await
            .unwrap();
        settle().await;

        assert!(registry.pause(timer_id).await.unwrap());
        assert_eq!(timer.get_state().await, TimerState::Paused);

        advance(Duration::from_secs(5)).await;
        settle().await;
        assert_eq!(timer.get_statistics().await.execution_count, 0);

        assert!(registry.resume(timer_id).await.unwrap());
        advance(Duration::from_secs(2)).await;
        settle().await;
        assert_eq!(
            timer.join().await.unwrap().reason,
            TimerFinishReason::Completed
        );
    }

    // Every id-based operation treats an unknown id as "not found" rather than an error.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_reports_unknown_ids_as_missing() {
        let registry = TimerRegistry::new();
        let missing = 42;

        assert!(!registry.contains(missing).await);
        assert!(registry.get(missing).await.is_none());
        assert!(registry.stop(missing).await.unwrap().is_none());
        assert!(registry.cancel(missing).await.unwrap().is_none());
        assert!(!registry.pause(missing).await.unwrap());
        assert!(!registry.resume(missing).await.unwrap());
        assert!(registry.remove(missing).await.is_none());
        assert!(registry.is_empty().await);
    }

    // Removing a timer untracks it but returns a handle that still controls the timer.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_remove_returns_controllable_handle() {
        let registry = TimerRegistry::new();
        let (timer_id, _timer) = registry
            .start_once(Duration::from_secs(5), || async { Ok(()) })
            .await
            .unwrap();

        let removed = registry
            .remove(timer_id)
            .await
            .expect("timer was just registered");
        assert!(!registry.contains(timer_id).await);
        assert!(registry.is_empty().await);

        let outcome = removed.cancel().await.unwrap();
        assert_eq!(outcome.reason, TimerFinishReason::Cancelled);
    }
}