1use std::collections::HashMap;
2use std::sync::{
3 atomic::{AtomicU64, Ordering},
4 Arc,
5};
6use std::time::Duration;
7
8use tokio::sync::RwLock;
9
10use crate::errors::TimerError;
11use crate::timer::{Timer, TimerCallback, TimerOutcome, TimerState};
12
13#[derive(Clone, Default)]
15pub struct TimerRegistry {
16 timers: Arc<RwLock<HashMap<u64, Timer>>>,
17 next_id: Arc<AtomicU64>,
18}
19
20impl TimerRegistry {
21 pub fn new() -> Self {
23 Self {
24 timers: Arc::new(RwLock::new(HashMap::new())),
25 next_id: Arc::new(AtomicU64::new(0)),
26 }
27 }
28
29 pub async fn insert(&self, timer: Timer) -> u64 {
31 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
32 self.timers.write().await.insert(id, timer);
33 id
34 }
35
36 pub async fn start_once<F>(
38 &self,
39 delay: Duration,
40 callback: F,
41 ) -> Result<(u64, Timer), TimerError>
42 where
43 F: TimerCallback + 'static,
44 {
45 let timer = Timer::new();
46 let _ = timer.start_once(delay, callback).await?;
47 let id = self.insert(timer.clone()).await;
48 Ok((id, timer))
49 }
50
51 pub async fn start_recurring<F>(
53 &self,
54 interval: Duration,
55 callback: F,
56 expiration_count: Option<usize>,
57 ) -> Result<(u64, Timer), TimerError>
58 where
59 F: TimerCallback + 'static,
60 {
61 let timer = Timer::new();
62 let _ = timer
63 .start_recurring(interval, callback, expiration_count)
64 .await?;
65 let id = self.insert(timer.clone()).await;
66 Ok((id, timer))
67 }
68
69 pub async fn remove(&self, id: u64) -> Option<Timer> {
71 self.timers.write().await.remove(&id)
72 }
73
74 pub async fn contains(&self, id: u64) -> bool {
76 self.timers.read().await.contains_key(&id)
77 }
78
79 pub async fn stop(&self, id: u64) -> Result<Option<TimerOutcome>, TimerError> {
81 let timer = self.get(id).await;
82 match timer {
83 Some(timer) => timer.stop().await.map(Some),
84 None => Ok(None),
85 }
86 }
87
88 pub async fn cancel(&self, id: u64) -> Result<Option<TimerOutcome>, TimerError> {
90 let timer = self.get(id).await;
91 match timer {
92 Some(timer) => timer.cancel().await.map(Some),
93 None => Ok(None),
94 }
95 }
96
97 pub async fn stop_all(&self) {
99 let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
100 for timer in timers {
101 let _ = timer.stop().await;
102 }
103 }
104
105 pub async fn join_all(&self) -> Vec<(u64, TimerOutcome)> {
107 let timers: Vec<(u64, Timer)> = self
108 .timers
109 .read()
110 .await
111 .iter()
112 .map(|(id, timer)| (*id, timer.clone()))
113 .collect();
114
115 let mut outcomes = Vec::with_capacity(timers.len());
116 for (id, timer) in timers {
117 if let Ok(outcome) = timer.join().await {
118 outcomes.push((id, outcome));
119 }
120 }
121
122 outcomes
123 }
124
125 pub async fn cancel_all(&self) {
127 let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
128 for timer in timers {
129 let _ = timer.cancel().await;
130 }
131 }
132
133 pub async fn active_ids(&self) -> Vec<u64> {
135 let timers: Vec<(u64, Timer)> = self
136 .timers
137 .read()
138 .await
139 .iter()
140 .map(|(id, timer)| (*id, timer.clone()))
141 .collect();
142
143 let mut active = Vec::new();
144 for (id, timer) in timers {
145 if timer.get_state().await != TimerState::Stopped {
146 active.push(id);
147 }
148 }
149 active
150 }
151
152 pub async fn get(&self, id: u64) -> Option<Timer> {
154 self.timers.read().await.get(&id).cloned()
155 }
156
157 pub async fn len(&self) -> usize {
159 self.timers.read().await.len()
160 }
161
162 pub async fn is_empty(&self) -> bool {
164 self.len().await == 0
165 }
166
167 pub async fn clear(&self) -> usize {
169 let mut timers = self.timers.write().await;
170 let removed = timers.len();
171 timers.clear();
172 removed
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::timer::TimerFinishReason;
    use tokio::task::yield_now;
    use tokio::time::advance;

    /// Yields to the current-thread scheduler several times so other tasks get a
    /// chance to run (e.g. after a paused-clock `advance`) before the test continues.
    async fn settle() {
        for _ in 0..5 {
            yield_now().await;
        }
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_start_helpers_are_easy_to_use() {
        let registry = TimerRegistry::new();
        let (once_id, once_timer) = registry
            .start_once(Duration::from_secs(1), || async { Ok(()) })
            .await
            .unwrap();
        let (recurring_id, recurring_timer) = registry
            .start_recurring(Duration::from_secs(2), || async { Ok(()) }, None)
            .await
            .unwrap();

        assert_ne!(once_id, recurring_id);
        assert_eq!(registry.len().await, 2);
        assert!(registry.get(once_id).await.is_some());

        advance(Duration::from_secs(1)).await;
        settle().await;
        assert_eq!(
            once_timer.join().await.unwrap().reason,
            TimerFinishReason::Completed
        );

        let active = registry.active_ids().await;
        assert!(active.contains(&recurring_id));

        // The recurring timer was started without an expiration count; cancel it so it
        // does not keep running past the test.
        let _ = recurring_timer.cancel().await.unwrap();
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_supports_direct_timer_controls() {
        let registry = TimerRegistry::new();
        let (timer_id, _timer) = registry
            .start_once(Duration::from_secs(5), || async { Ok(()) })
            .await
            .unwrap();

        assert!(registry.contains(timer_id).await);
        let outcome = registry.cancel(timer_id).await.unwrap().unwrap();
        assert_eq!(outcome.reason, TimerFinishReason::Cancelled);
        assert_eq!(registry.clear().await, 1);
        assert!(registry.is_empty().await);
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_handles_unknown_and_removed_ids() {
        let registry = TimerRegistry::new();

        // Operations on an id that was never registered report absence, not errors.
        assert!(!registry.contains(42).await);
        assert!(registry.get(42).await.is_none());
        assert!(registry.remove(42).await.is_none());
        assert!(registry.stop(42).await.unwrap().is_none());
        assert!(registry.cancel(42).await.unwrap().is_none());
        assert_eq!(registry.clear().await, 0);

        let (timer_id, _timer) = registry
            .start_once(Duration::from_secs(5), || async { Ok(()) })
            .await
            .unwrap();
        let removed = registry
            .remove(timer_id)
            .await
            .expect("timer was just registered");
        assert!(!registry.contains(timer_id).await);
        assert!(registry.is_empty().await);
        assert!(registry.cancel(timer_id).await.unwrap().is_none());

        // `remove` only unregisters the timer; cancel it directly so it does not fire.
        let _ = removed.cancel().await.unwrap();
    }
}
232}