// timer_lib/registry.rs

1use std::collections::HashMap;
2use std::sync::{
3    atomic::{AtomicU64, Ordering},
4    Arc,
5};
6use std::time::Duration;
7
8use tokio::sync::RwLock;
9use tokio::time::Instant;
10
11use crate::errors::TimerError;
12use crate::timer::driver::RuntimeHandle;
13use crate::timer::{
14    RecurringSchedule, Timer, TimerCallback, TimerMetadata, TimerOutcome, TimerSnapshot, TimerState,
15};
16
17/// Snapshot of a timer tracked by the registry.
18#[derive(Debug, Clone, PartialEq, Eq)]
19pub struct RegisteredTimer {
20    /// Registry identifier for the timer.
21    pub id: u64,
22    /// Current or most recent timer state.
23    pub state: TimerState,
24    /// Effective timer interval.
25    pub interval: Duration,
26    /// Optional recurring execution limit.
27    pub expiration_count: Option<usize>,
28    /// Run statistics captured from the timer.
29    pub statistics: crate::timer::TimerStatistics,
30    /// Most recent completed outcome, if any.
31    pub last_outcome: Option<TimerOutcome>,
32    /// Metadata associated with the timer.
33    pub metadata: TimerMetadata,
34}
35
36/// A registry for tracking timers by identifier.
37#[derive(Clone, Default)]
38pub struct TimerRegistry {
39    timers: Arc<RwLock<HashMap<u64, Timer>>>,
40    next_id: Arc<AtomicU64>,
41    runtime: RuntimeHandle,
42}
43
44impl TimerRegistry {
45    /// Creates a new timer registry.
46    pub fn new() -> Self {
47        Self {
48            timers: Arc::new(RwLock::new(HashMap::new())),
49            next_id: Arc::new(AtomicU64::new(0)),
50            runtime: RuntimeHandle::default(),
51        }
52    }
53
54    /// Creates a new registry backed by a manually-driven test runtime.
55    #[cfg(feature = "test-util")]
56    pub fn new_mocked() -> (Self, crate::timer::MockRuntime) {
57        let runtime = crate::timer::MockRuntime::new();
58        (
59            Self {
60                timers: Arc::new(RwLock::new(HashMap::new())),
61                next_id: Arc::new(AtomicU64::new(0)),
62                runtime: runtime.handle(),
63            },
64            runtime,
65        )
66    }
67
68    /// Inserts an existing timer and returns its identifier.
69    pub async fn insert(&self, timer: Timer) -> u64 {
70        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
71        self.timers.write().await.insert(id, timer);
72        id
73    }
74
75    /// Starts and registers a one-time timer.
76    pub async fn start_once<F>(
77        &self,
78        delay: Duration,
79        callback: F,
80    ) -> Result<(u64, Timer), TimerError>
81    where
82        F: TimerCallback + 'static,
83    {
84        let timer = Timer::new_with_runtime(self.runtime.clone(), true);
85        let _ = timer.start_once(delay, callback).await?;
86        let id = self.insert(timer.clone()).await;
87        Ok((id, timer))
88    }
89
90    /// Starts and registers a one-time timer at a deadline.
91    pub async fn start_at<F>(
92        &self,
93        deadline: Instant,
94        callback: F,
95    ) -> Result<(u64, Timer), TimerError>
96    where
97        F: TimerCallback + 'static,
98    {
99        let timer = Timer::new_with_runtime(self.runtime.clone(), true);
100        let _ = timer.start_at(deadline, callback).await?;
101        let id = self.insert(timer.clone()).await;
102        Ok((id, timer))
103    }
104
105    /// Starts and registers a recurring timer.
106    pub async fn start_recurring<F>(
107        &self,
108        schedule: RecurringSchedule,
109        callback: F,
110    ) -> Result<(u64, Timer), TimerError>
111    where
112        F: TimerCallback + 'static,
113    {
114        let timer = Timer::new_with_runtime(self.runtime.clone(), true);
115        let _ = timer.start_recurring(schedule, callback).await?;
116        let id = self.insert(timer.clone()).await;
117        Ok((id, timer))
118    }
119
120    /// Removes a timer from the registry and returns it.
121    pub async fn remove(&self, id: u64) -> Option<Timer> {
122        self.timers.write().await.remove(&id)
123    }
124
125    /// Returns true when the registry tracks the given timer identifier.
126    pub async fn contains(&self, id: u64) -> bool {
127        self.timers.read().await.contains_key(&id)
128    }
129
130    /// Stops a timer by identifier when it exists.
131    pub async fn stop(&self, id: u64) -> Result<Option<TimerOutcome>, TimerError> {
132        let timer = self.get(id).await;
133        match timer {
134            Some(timer) => timer.stop().await.map(Some),
135            None => Ok(None),
136        }
137    }
138
139    /// Cancels a timer by identifier when it exists.
140    pub async fn cancel(&self, id: u64) -> Result<Option<TimerOutcome>, TimerError> {
141        let timer = self.get(id).await;
142        match timer {
143            Some(timer) => timer.cancel().await.map(Some),
144            None => Ok(None),
145        }
146    }
147
148    /// Pauses a timer by identifier when it exists.
149    pub async fn pause(&self, id: u64) -> Result<bool, TimerError> {
150        let timer = self.get(id).await;
151        match timer {
152            Some(timer) => {
153                timer.pause().await?;
154                Ok(true)
155            }
156            None => Ok(false),
157        }
158    }
159
160    /// Resumes a timer by identifier when it exists.
161    pub async fn resume(&self, id: u64) -> Result<bool, TimerError> {
162        let timer = self.get(id).await;
163        match timer {
164            Some(timer) => {
165                timer.resume().await?;
166                Ok(true)
167            }
168            None => Ok(false),
169        }
170    }
171
172    /// Stops all timers currently tracked by the registry.
173    pub async fn stop_all(&self) {
174        let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
175        for timer in timers {
176            let _ = timer.stop().await;
177        }
178    }
179
180    /// Pauses all running timers currently tracked by the registry.
181    pub async fn pause_all(&self) {
182        let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
183        for timer in timers {
184            let _ = timer.pause().await;
185        }
186    }
187
188    /// Waits for all tracked timers that have a joinable outcome.
189    pub async fn join_all(&self) -> Vec<(u64, TimerOutcome)> {
190        let timers: Vec<(u64, Timer)> = self
191            .timers
192            .read()
193            .await
194            .iter()
195            .map(|(id, timer)| (*id, timer.clone()))
196            .collect();
197
198        let mut outcomes = Vec::with_capacity(timers.len());
199        for (id, timer) in timers {
200            if let Ok(outcome) = timer.join().await {
201                outcomes.push((id, outcome));
202            }
203        }
204
205        outcomes
206    }
207
208    /// Cancels all timers currently tracked by the registry.
209    pub async fn cancel_all(&self) {
210        let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
211        for timer in timers {
212            let _ = timer.cancel().await;
213        }
214    }
215
216    /// Resumes all paused timers currently tracked by the registry.
217    pub async fn resume_all(&self) {
218        let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
219        for timer in timers {
220            let _ = timer.resume().await;
221        }
222    }
223
224    /// Lists all active timers.
225    pub async fn active_ids(&self) -> Vec<u64> {
226        let timers: Vec<(u64, Timer)> = self
227            .timers
228            .read()
229            .await
230            .iter()
231            .map(|(id, timer)| (*id, timer.clone()))
232            .collect();
233
234        let mut active = Vec::new();
235        for (id, timer) in timers {
236            if timer.get_state().await != TimerState::Stopped {
237                active.push(id);
238            }
239        }
240        active
241    }
242
243    /// Retrieves a timer by ID.
244    pub async fn get(&self, id: u64) -> Option<Timer> {
245        self.timers.read().await.get(&id).cloned()
246    }
247
248    /// Returns a snapshot of a tracked timer by identifier.
249    pub async fn snapshot(&self, id: u64) -> Option<RegisteredTimer> {
250        let timer = self.get(id).await?;
251        Some(RegisteredTimer::from_snapshot(id, timer.snapshot().await))
252    }
253
254    /// Lists snapshots for all tracked timers.
255    pub async fn list(&self) -> Vec<RegisteredTimer> {
256        let timers: Vec<(u64, Timer)> = self
257            .timers
258            .read()
259            .await
260            .iter()
261            .map(|(id, timer)| (*id, timer.clone()))
262            .collect();
263
264        let mut listed = Vec::with_capacity(timers.len());
265        for (id, timer) in timers {
266            listed.push(RegisteredTimer::from_snapshot(id, timer.snapshot().await));
267        }
268        listed
269    }
270
271    /// Returns the identifiers for timers carrying a matching label.
272    pub async fn find_by_label(&self, label: &str) -> Vec<u64> {
273        let snapshots = self.list().await;
274        snapshots
275            .into_iter()
276            .filter(|timer| timer.metadata.label.as_deref() == Some(label))
277            .map(|timer| timer.id)
278            .collect()
279    }
280
281    /// Returns the number of tracked timers.
282    pub async fn len(&self) -> usize {
283        self.timers.read().await.len()
284    }
285
286    /// Returns true when the registry is empty.
287    pub async fn is_empty(&self) -> bool {
288        self.len().await == 0
289    }
290
291    /// Removes all tracked timers and returns the number removed.
292    pub async fn clear(&self) -> usize {
293        let mut timers = self.timers.write().await;
294        let removed = timers.len();
295        timers.clear();
296        removed
297    }
298}
299
300impl RegisteredTimer {
301    fn from_snapshot(id: u64, snapshot: TimerSnapshot) -> Self {
302        Self {
303            id,
304            state: snapshot.state,
305            interval: snapshot.interval,
306            expiration_count: snapshot.expiration_count,
307            statistics: snapshot.statistics,
308            last_outcome: snapshot.last_outcome,
309            metadata: snapshot.metadata,
310        }
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::timer::TimerFinishReason;
    use tokio::task::yield_now;
    use tokio::time::advance;

    /// Yields to the scheduler a few times so spawned timer tasks can make progress.
    async fn settle() {
        let mut remaining = 5;
        while remaining > 0 {
            yield_now().await;
            remaining -= 1;
        }
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_start_helpers_are_easy_to_use() {
        let registry = TimerRegistry::new();

        let started_once = registry
            .start_once(Duration::from_secs(1), || async { Ok(()) })
            .await;
        let (once_id, once_timer) = started_once.unwrap();

        let started_recurring = registry
            .start_recurring(RecurringSchedule::new(Duration::from_secs(2)), || async {
                Ok(())
            })
            .await;
        let (recurring_id, recurring_timer) = started_recurring.unwrap();

        // Both timers are tracked under distinct ids.
        assert_ne!(once_id, recurring_id);
        assert_eq!(registry.len().await, 2);
        assert!(registry.get(once_id).await.is_some());

        // Let the one-shot timer fire.
        advance(Duration::from_secs(1)).await;
        settle().await;
        let finished = once_timer.join().await.unwrap();
        assert_eq!(finished.reason, TimerFinishReason::Completed);

        // The recurring timer is still reported as active.
        let active_ids = registry.active_ids().await;
        assert!(active_ids.contains(&recurring_id));

        let _ = recurring_timer.cancel().await.unwrap();
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_supports_direct_timer_controls() {
        let registry = TimerRegistry::new();
        // Keep the handle alive for the whole test, as the registry also does.
        let (id, _handle) = registry
            .start_once(Duration::from_secs(5), || async { Ok(()) })
            .await
            .unwrap();

        assert!(registry.contains(id).await);

        let cancelled = registry.cancel(id).await.unwrap().unwrap();
        assert_eq!(cancelled.reason, TimerFinishReason::Cancelled);

        let removed = registry.clear().await;
        assert_eq!(removed, 1);
        assert!(registry.is_empty().await);
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_can_pause_and_resume_tracked_timers() {
        let registry = TimerRegistry::new();
        let schedule = RecurringSchedule::new(Duration::from_secs(2)).with_expiration_count(1);
        let (id, handle) = registry
            .start_recurring(schedule, || async { Ok(()) })
            .await
            .unwrap();
        settle().await;

        // Pausing through the registry is visible on the timer handle.
        let paused = registry.pause(id).await.unwrap();
        assert!(paused);
        assert_eq!(handle.get_state().await, TimerState::Paused);

        // While paused, advancing time must not run the callback.
        advance(Duration::from_secs(5)).await;
        settle().await;
        let stats = handle.get_statistics().await;
        assert_eq!(stats.execution_count, 0);

        // After resuming, one interval is enough to complete the single run.
        let resumed = registry.resume(id).await.unwrap();
        assert!(resumed);
        advance(Duration::from_secs(2)).await;
        settle().await;
        let finished = handle.join().await.unwrap();
        assert_eq!(finished.reason, TimerFinishReason::Completed);
    }
}