tokio_switching_sleep/
lib.rs

//! This crate provides two types: [`SwitchingSleep`] and [`ASwitchingSleep`].
2//!
3//! [`ASwitchingSleep`] is just a wrapper around
4//! [`Arc`](struct@std::sync::Arc)<[`RwLock`](struct@tokio::sync::RwLock)<[`SwitchingSleep`]>>.
5//!
//! Both behave like a [`tokio::time::Sleep`](struct@tokio::time::Sleep)
//! whose state can be switched on and off: calling [`start`] creates a fresh
//! [`Sleep`] and calling [`stop`] drops the current one, so calling [`start`]
//! again resets the timer.
10//!
//! The timer completes once `period` has passed since the latest call to
//! [`start`] (or [`new_start`], or [`new`] followed by [`start`]). After it has
//! elapsed, further calls to [`start`] and [`stop`] have no effect.
13//!
//! [`SwitchingSleep`]: struct@SwitchingSleep
//! [`Sleep`]: struct@tokio::time::Sleep
16//! [`start`]: SwitchingSleep::start()
17//! [`stop`]: SwitchingSleep::stop()
18//! [`new_start`]: SwitchingSleep::new_start()
19//! [`new`]: SwitchingSleep::new()
20
21use std::{
22    fmt::Debug,
23    future::Future,
24    pin::Pin,
25    sync::Arc,
26    task::{Context, Poll},
27    time::Duration,
28};
29
30use tokio::{
31    sync::{broadcast, RwLock},
32    time::{sleep, Sleep},
33};
34
35/// The [`!Sync`][trait@std::marker::Sync] one.
36#[derive(Debug)]
37pub struct SwitchingSleep {
38    period: Duration,
39    tx: broadcast::Sender<()>,
40    rx: broadcast::Receiver<()>,
41    sleeper: Option<Sleep>,
42}
43
44impl Unpin for SwitchingSleep {}
45
46impl SwitchingSleep {
47    /// Create a new [`SwitchingSleep`] and doesn't start the timer.
48    pub fn new(period: Duration) -> Self {
49        let (tx, rx) = broadcast::channel(10);
50
51        Self {
52            period,
53            tx,
54            rx,
55            sleeper: None,
56        }
57    }
58
59    /// Create a new [`SwitchingSleep`] and start the timer.
60    pub fn new_start(period: Duration) -> Self {
61        let mut me = Self::new(period);
62        me.start();
63        me
64    }
65
66    /// Start the timer. Reset if already started.
67    pub fn start(&mut self) {
68        if !self.is_elapsed() {
69            self.stop();
70
71            self.sleeper = Some(sleep(self.period));
72            self.tx.send(()).unwrap();
73        }
74    }
75
76    /// Stop the timer. It doesn nothing if already stopped.
77    pub fn stop(&mut self) {
78        if !self.is_elapsed() {
79            match self.sleeper.take() {
80                Some(_) => {
81                    self.tx.send(()).unwrap();
82                }
83                None => (),
84            }
85        }
86    }
87
88    /// Check if the timer (if any) is elapsed.
89    pub fn is_elapsed(&self) -> bool {
90        self.sleeper.is_some() && (&self.sleeper).as_ref().unwrap().is_elapsed()
91    }
92}
93
94unsafe impl Send for SwitchingSleep {}
95
96impl Future for SwitchingSleep {
97    type Output = ();
98
99    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<<Self as Future>::Output> {
100        unsafe {
101            let me = Pin::get_unchecked_mut(self);
102
103            if me.is_elapsed() {
104                return Poll::Ready(());
105            }
106
107            let sleeper = match me.sleeper {
108                Some(ref mut sleeper) => {
109                    let sleeper = Pin::new_unchecked(sleeper);
110
111                    Some(sleeper.poll(cx))
112                }
113                None => None,
114            };
115            let mut recv = me.rx.recv();
116            let recv = Pin::new_unchecked(&mut recv);
117            let _ = recv.poll(cx);
118
119            if let Some(Poll::Ready(_)) = sleeper {
120                Poll::Ready(())
121            } else {
122                Poll::Pending
123            }
124        }
125    }
126}
127
128/// The [`Sync`][trait@std::marker::Sync] one.
129#[derive(Debug)]
130pub struct ASwitchingSleep(Arc<RwLock<SwitchingSleep>>);
131
132impl ASwitchingSleep {
133    /// Create a new [`ASwitchingSleep`] and doesn't start the timer.
134    pub fn new(period: Duration) -> Self {
135        Self(Arc::new(RwLock::new(SwitchingSleep::new(period))))
136    }
137
138    /// Create a new [`ASwitchingSleep`] and start the timer.
139    pub async fn new_start(period: Duration) -> Self {
140        let me = Self::new(period);
141        me.start().await;
142        me
143    }
144
145    /// Start the timer. Reset if already started.
146    pub async fn start(&self) {
147        let mut inner = self.0.write().await;
148        inner.start()
149    }
150
151    /// Stop the timer. It doesn nothing if already stopped.
152    pub async fn stop(&self) {
153        let mut inner = self.0.write().await;
154        inner.stop()
155    }
156
157    /// Check if the timer (if any) is elapsed.
158    pub async fn is_elapsed(&self) -> bool {
159        let inner = self.0.read().await;
160        inner.is_elapsed()
161    }
162}
163
// NOTE(review): these impls override the compiler's auto-trait checks.
// `Arc<RwLock<T>>` is already `Send + Sync` whenever `T: Send + Sync`, and
// `Arc` is always `Unpin`, so they look redundant if `SwitchingSleep` is
// `Send + Sync`. If `SwitchingSleep` is NOT `Sync`, the `Sync`/`Send` impls
// are unsound: the `RwLock` hands out `&SwitchingSleep` to concurrent readers
// (see `is_elapsed`, which uses `read()`). TODO confirm the tokio auto traits
// and prefer the compiler-derived impls.
unsafe impl Send for ASwitchingSleep {}
unsafe impl Sync for ASwitchingSleep {}
impl Unpin for ASwitchingSleep {}
167
168impl Clone for ASwitchingSleep {
169    fn clone(&self) -> Self {
170        Self(self.0.clone())
171    }
172}
173
174impl Future for ASwitchingSleep {
175    type Output = ();
176
177    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<<Self as Future>::Output> {
178        unsafe {
179            let me = Pin::get_unchecked_mut(self);
180
181            let mut inner = me.0.write();
182            let inner = Pin::new_unchecked(&mut inner);
183
184            match inner.poll(cx) {
185                Poll::Pending => Poll::Pending,
186                Poll::Ready(mut inner) => Pin::new_unchecked(&mut *inner).poll(cx),
187            }
188        }
189    }
190}
191
#[cfg(test)]
mod test {
    use super::*;
    use std::time::Duration;
    use tokio::{
        select,
        time::{sleep, Instant},
    };

    /// Drives a 3 s timer from a separate task and checks when it fires.
    ///
    /// Timeline (real wall-clock time, so the test takes about 12 s):
    /// - t=5s: timer started (would fire at t=8s)
    /// - t=7s: timer stopped
    /// - t=9s: timer restarted (fires at t=12s)
    /// - t=11s: controller task finishes
    #[tokio::test]
    async fn it_works() {
        let mut sleeper = ASwitchingSleep::new(Duration::from_secs(3));

        let start = Instant::now();

        let mut task = {
            let sleeper = sleeper.clone();
            tokio::task::spawn(async move {
                sleep(Duration::from_secs(5)).await;
                assert!(!sleeper.is_elapsed().await);
                sleeper.start().await;

                sleep(Duration::from_secs(2)).await;
                assert!(!sleeper.is_elapsed().await);
                sleeper.stop().await;

                sleep(Duration::from_secs(2)).await;
                assert!(!sleeper.is_elapsed().await);
                sleeper.start().await;

                sleep(Duration::from_secs(2)).await;
                assert!(!sleeper.is_elapsed().await);
            })
        };

        select! {
            joined = &mut task => {
                // A panic in the spawned task (e.g. a failed assertion) is only
                // visible through the `JoinHandle` result; surface it here
                // instead of silently ignoring it.
                joined.expect("controller task panicked");
                (&mut sleeper).await;
            },
            _ = &mut sleeper => {},
        }

        let diff = Instant::now() - start;

        assert!(sleeper.is_elapsed().await);
        assert_eq!(diff.as_secs(), 12);
    }
}
256}