1use std::collections::HashMap;
2use std::sync::{
3 atomic::{AtomicU64, Ordering},
4 Arc,
5};
6use std::time::Duration;
7
8use tokio::sync::RwLock;
9use tokio::time::Instant;
10
11use crate::errors::TimerError;
12use crate::timer::driver::RuntimeHandle;
13use crate::timer::{
14 RecurringSchedule, Timer, TimerCallback, TimerMetadata, TimerOutcome, TimerSnapshot, TimerState,
15};
16
17#[derive(Debug, Clone, PartialEq, Eq)]
19pub struct RegisteredTimer {
20 pub id: u64,
22 pub state: TimerState,
24 pub interval: Duration,
26 pub expiration_count: Option<usize>,
28 pub statistics: crate::timer::TimerStatistics,
30 pub last_outcome: Option<TimerOutcome>,
32 pub metadata: TimerMetadata,
34}
35
36#[derive(Clone, Default)]
38pub struct TimerRegistry {
39 timers: Arc<RwLock<HashMap<u64, Timer>>>,
40 next_id: Arc<AtomicU64>,
41 runtime: RuntimeHandle,
42}
43
44impl TimerRegistry {
45 pub fn new() -> Self {
47 Self {
48 timers: Arc::new(RwLock::new(HashMap::new())),
49 next_id: Arc::new(AtomicU64::new(0)),
50 runtime: RuntimeHandle::default(),
51 }
52 }
53
54 #[cfg(feature = "test-util")]
56 pub fn new_mocked() -> (Self, crate::timer::MockRuntime) {
57 let runtime = crate::timer::MockRuntime::new();
58 (
59 Self {
60 timers: Arc::new(RwLock::new(HashMap::new())),
61 next_id: Arc::new(AtomicU64::new(0)),
62 runtime: runtime.handle(),
63 },
64 runtime,
65 )
66 }
67
68 pub async fn insert(&self, timer: Timer) -> u64 {
70 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
71 self.timers.write().await.insert(id, timer);
72 id
73 }
74
75 pub async fn start_once<F>(
77 &self,
78 delay: Duration,
79 callback: F,
80 ) -> Result<(u64, Timer), TimerError>
81 where
82 F: TimerCallback + 'static,
83 {
84 let timer = Timer::new_with_runtime(self.runtime.clone(), true);
85 let _ = timer.start_once(delay, callback).await?;
86 let id = self.insert(timer.clone()).await;
87 Ok((id, timer))
88 }
89
90 pub async fn start_at<F>(
92 &self,
93 deadline: Instant,
94 callback: F,
95 ) -> Result<(u64, Timer), TimerError>
96 where
97 F: TimerCallback + 'static,
98 {
99 let timer = Timer::new_with_runtime(self.runtime.clone(), true);
100 let _ = timer.start_at(deadline, callback).await?;
101 let id = self.insert(timer.clone()).await;
102 Ok((id, timer))
103 }
104
105 pub async fn start_recurring<F>(
107 &self,
108 schedule: RecurringSchedule,
109 callback: F,
110 ) -> Result<(u64, Timer), TimerError>
111 where
112 F: TimerCallback + 'static,
113 {
114 let timer = Timer::new_with_runtime(self.runtime.clone(), true);
115 let _ = timer.start_recurring(schedule, callback).await?;
116 let id = self.insert(timer.clone()).await;
117 Ok((id, timer))
118 }
119
120 pub async fn remove(&self, id: u64) -> Option<Timer> {
122 self.timers.write().await.remove(&id)
123 }
124
125 pub async fn contains(&self, id: u64) -> bool {
127 self.timers.read().await.contains_key(&id)
128 }
129
130 pub async fn stop(&self, id: u64) -> Result<Option<TimerOutcome>, TimerError> {
132 let timer = self.get(id).await;
133 match timer {
134 Some(timer) => timer.stop().await.map(Some),
135 None => Ok(None),
136 }
137 }
138
139 pub async fn cancel(&self, id: u64) -> Result<Option<TimerOutcome>, TimerError> {
141 let timer = self.get(id).await;
142 match timer {
143 Some(timer) => timer.cancel().await.map(Some),
144 None => Ok(None),
145 }
146 }
147
148 pub async fn pause(&self, id: u64) -> Result<bool, TimerError> {
150 let timer = self.get(id).await;
151 match timer {
152 Some(timer) => {
153 timer.pause().await?;
154 Ok(true)
155 }
156 None => Ok(false),
157 }
158 }
159
160 pub async fn resume(&self, id: u64) -> Result<bool, TimerError> {
162 let timer = self.get(id).await;
163 match timer {
164 Some(timer) => {
165 timer.resume().await?;
166 Ok(true)
167 }
168 None => Ok(false),
169 }
170 }
171
172 pub async fn stop_all(&self) {
174 let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
175 for timer in timers {
176 let _ = timer.stop().await;
177 }
178 }
179
180 pub async fn pause_all(&self) {
182 let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
183 for timer in timers {
184 let _ = timer.pause().await;
185 }
186 }
187
188 pub async fn join_all(&self) -> Vec<(u64, TimerOutcome)> {
190 let timers: Vec<(u64, Timer)> = self
191 .timers
192 .read()
193 .await
194 .iter()
195 .map(|(id, timer)| (*id, timer.clone()))
196 .collect();
197
198 let mut outcomes = Vec::with_capacity(timers.len());
199 for (id, timer) in timers {
200 if let Ok(outcome) = timer.join().await {
201 outcomes.push((id, outcome));
202 }
203 }
204
205 outcomes
206 }
207
208 pub async fn cancel_all(&self) {
210 let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
211 for timer in timers {
212 let _ = timer.cancel().await;
213 }
214 }
215
216 pub async fn resume_all(&self) {
218 let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
219 for timer in timers {
220 let _ = timer.resume().await;
221 }
222 }
223
224 pub async fn active_ids(&self) -> Vec<u64> {
226 let timers: Vec<(u64, Timer)> = self
227 .timers
228 .read()
229 .await
230 .iter()
231 .map(|(id, timer)| (*id, timer.clone()))
232 .collect();
233
234 let mut active = Vec::new();
235 for (id, timer) in timers {
236 if timer.get_state().await != TimerState::Stopped {
237 active.push(id);
238 }
239 }
240 active
241 }
242
243 pub async fn get(&self, id: u64) -> Option<Timer> {
245 self.timers.read().await.get(&id).cloned()
246 }
247
248 pub async fn snapshot(&self, id: u64) -> Option<RegisteredTimer> {
250 let timer = self.get(id).await?;
251 Some(RegisteredTimer::from_snapshot(id, timer.snapshot().await))
252 }
253
254 pub async fn list(&self) -> Vec<RegisteredTimer> {
256 let timers: Vec<(u64, Timer)> = self
257 .timers
258 .read()
259 .await
260 .iter()
261 .map(|(id, timer)| (*id, timer.clone()))
262 .collect();
263
264 let mut listed = Vec::with_capacity(timers.len());
265 for (id, timer) in timers {
266 listed.push(RegisteredTimer::from_snapshot(id, timer.snapshot().await));
267 }
268 listed
269 }
270
271 pub async fn find_by_label(&self, label: &str) -> Vec<u64> {
273 let snapshots = self.list().await;
274 snapshots
275 .into_iter()
276 .filter(|timer| timer.metadata.label.as_deref() == Some(label))
277 .map(|timer| timer.id)
278 .collect()
279 }
280
281 pub async fn len(&self) -> usize {
283 self.timers.read().await.len()
284 }
285
286 pub async fn is_empty(&self) -> bool {
288 self.len().await == 0
289 }
290
291 pub async fn clear(&self) -> usize {
293 let mut timers = self.timers.write().await;
294 let removed = timers.len();
295 timers.clear();
296 removed
297 }
298}
299
300impl RegisteredTimer {
301 fn from_snapshot(id: u64, snapshot: TimerSnapshot) -> Self {
302 Self {
303 id,
304 state: snapshot.state,
305 interval: snapshot.interval,
306 expiration_count: snapshot.expiration_count,
307 statistics: snapshot.statistics,
308 last_outcome: snapshot.last_outcome,
309 metadata: snapshot.metadata,
310 }
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::timer::TimerFinishReason;
    use tokio::task::yield_now;
    use tokio::time::advance;

    /// Yields to the runtime five times in a row.
    // NOTE(review): presumably this gives timer tasks a chance to run after
    // the paused clock has been advanced — confirm against the timer driver.
    async fn settle() {
        let mut remaining = 5;
        while remaining > 0 {
            yield_now().await;
            remaining -= 1;
        }
    }

    /// Starts a one-shot and a recurring timer through the registry and checks
    /// ids, bookkeeping, completion of the one-shot, and `active_ids`.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_start_helpers_are_easy_to_use() {
        let registry = TimerRegistry::new();

        let (one_shot_id, one_shot) = registry
            .start_once(Duration::from_secs(1), || async { Ok(()) })
            .await
            .unwrap();

        let every_two_seconds = RecurringSchedule::new(Duration::from_secs(2));
        let (repeating_id, repeating) = registry
            .start_recurring(every_two_seconds, || async { Ok(()) })
            .await
            .unwrap();

        assert_ne!(one_shot_id, repeating_id);
        assert_eq!(registry.len().await, 2);
        assert!(registry.get(one_shot_id).await.is_some());

        advance(Duration::from_secs(1)).await;
        settle().await;
        let finished = one_shot.join().await.unwrap();
        assert_eq!(finished.reason, TimerFinishReason::Completed);

        assert!(registry.active_ids().await.contains(&repeating_id));

        let _ = repeating.cancel().await.unwrap();
    }

    /// Cancels a timer by id and then empties the registry.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_supports_direct_timer_controls() {
        let registry = TimerRegistry::new();
        // The handle is kept bound (not `_`) so it lives until the test ends.
        let (tracked_id, _handle) = registry
            .start_once(Duration::from_secs(5), || async { Ok(()) })
            .await
            .unwrap();

        assert!(registry.contains(tracked_id).await);

        let cancelled = registry.cancel(tracked_id).await.unwrap().unwrap();
        assert_eq!(cancelled.reason, TimerFinishReason::Cancelled);

        let removed = registry.clear().await;
        assert_eq!(removed, 1);
        assert!(registry.is_empty().await);
    }

    /// Pauses a recurring timer by id, checks that it does not execute while
    /// paused, then resumes it and waits for completion.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_can_pause_and_resume_tracked_timers() {
        let registry = TimerRegistry::new();
        let single_expiration =
            RecurringSchedule::new(Duration::from_secs(2)).with_expiration_count(1);
        let (tracked_id, handle) = registry
            .start_recurring(single_expiration, || async { Ok(()) })
            .await
            .unwrap();
        settle().await;

        let paused = registry.pause(tracked_id).await.unwrap();
        assert!(paused);
        assert_eq!(handle.get_state().await, TimerState::Paused);

        advance(Duration::from_secs(5)).await;
        settle().await;
        let stats = handle.get_statistics().await;
        assert_eq!(stats.execution_count, 0);

        let resumed = registry.resume(tracked_id).await.unwrap();
        assert!(resumed);
        advance(Duration::from_secs(2)).await;
        settle().await;
        let finished = handle.join().await.unwrap();
        assert_eq!(finished.reason, TimerFinishReason::Completed);
    }
}
400}