1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
#![forbid(unsafe_code)]

use crate::{schedule_wake, TimerThreadNotStarted};
use core::future::Future;
use core::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use std::time::Instant;

/// A future that completes after the specified time.
///
/// [`sleep_until`] and [`sleep_for`] await one of these internally and panic if it
/// fails. Construct it directly with [`SleepFuture::new`] to receive the
/// [`TimerThreadNotStarted`] error instead of panicking.
#[derive(Debug)]
#[must_use = "futures stay idle unless you await them"]
pub struct SleepFuture {
    /// The future reports ready once this instant has passed.
    deadline: Instant,
    /// Slot holding the waker from the most recent pending poll. The `Arc` is handed to
    /// `schedule_wake` so the timer thread can reach the same slot.
    waker: Arc<Mutex<Option<Waker>>>,
}
impl SleepFuture {
    /// Makes a future that completes after `deadline`.
    ///
    /// Awaiting it resolves to `Err(TimerThreadNotStarted)` when the wakeup cannot be
    /// scheduled, rather than panicking like [`sleep_until`] does.
    pub fn new(deadline: Instant) -> Self {
        Self {
            deadline,
            waker: Arc::new(Mutex::new(None)),
        }
    }
}
impl Future for SleepFuture {
    type Output = Result<(), TimerThreadNotStarted>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.deadline < std::time::Instant::now() {
            return Poll::Ready(Ok(()));
        }
        let old_waker = self.waker.lock().unwrap().replace(cx.waker().clone());
        if old_waker.is_none() {
            schedule_wake(self.deadline, self.waker.clone())?;
        }
        Poll::Pending
    }
}

/// Waits until `deadline` has passed.
///
/// # Panics
///
/// Panics if [`start_timer_thread()`](fn.start_timer_thread.html) has not been called.
/// To handle that case as an error instead, await a
/// [`SleepFuture::new`](struct.SleepFuture.html) directly.
pub async fn sleep_until(deadline: std::time::Instant) {
    let outcome = SleepFuture::new(deadline).await;
    outcome.unwrap();
}

/// Returns `duration` time from now.
///
/// Panics if [`start_timer_thread()`](fn.start_timer_thread.html) has not been called.
/// If you need to handle this error, use [`SleepFuture::new`](struct.SleepFuture.html).
pub async fn sleep_for(duration: core::time::Duration) {
    SleepFuture::new(Instant::now() + duration).await.unwrap()
}

#[cfg(test)]
mod tests {
    use super::super::{expect_elapsed, start_timer_thread, FakeWaker};
    use super::*;
    use core::future::Future;
    use core::sync::atomic::{AtomicBool, Ordering};
    use core::task::{Context, Poll};
    use core::time::Duration;
    use rusty_fork::rusty_fork_test;
    use std::convert::TryFrom;
    use std::sync::Arc;
    use std::time::Instant;

    /// Checks that both sleep helpers panic with `TimerThreadNotStarted` before
    /// `start_timer_thread()` is called, and work (also on repeated starts) afterwards.
    ///
    /// Run through `rusty_fork_test!` below, presumably because the other tests in this
    /// module start the timer thread in the shared test process.
    fn timer_thread_not_started_inner() {
        assert!(std::panic::catch_unwind(|| {
            safina_executor::block_on(async move { sleep_for(Duration::from_millis(10)).await })
        })
        .unwrap_err()
        .downcast::<String>()
        .unwrap()
        .contains("TimerThreadNotStarted"));

        assert!(std::panic::catch_unwind(|| {
            safina_executor::block_on(async move {
                sleep_until(Instant::now() + Duration::from_millis(10)).await
            })
        })
        .unwrap_err()
        .downcast::<String>()
        .unwrap()
        .contains("TimerThreadNotStarted"));

        for _ in 0..2 {
            start_timer_thread();
            safina_executor::block_on(async move {
                sleep_for(Duration::from_millis(10)).await;
            });
            safina_executor::block_on(async move {
                sleep_until(Instant::now() + Duration::from_millis(10)).await;
            });
        }
    }
    rusty_fork_test! {
        #[test]
        fn timer_thread_not_started() {
            timer_thread_not_started_inner()
        }
    }

    #[test]
    pub fn test_sleep_for() {
        start_timer_thread();
        let before = Instant::now();
        safina_executor::block_on(async move {
            sleep_for(Duration::from_millis(100)).await;
        });
        expect_elapsed(before, 100..200);
    }

    #[test]
    pub fn test_sleep_for_zero() {
        start_timer_thread();
        let before = Instant::now();
        safina_executor::block_on(async move {
            sleep_for(Duration::from_secs(0)).await;
        });
        expect_elapsed(before, 0..90);
    }

    #[test]
    pub fn test_sleep_until() {
        start_timer_thread();
        let before = Instant::now();
        let deadline = before + Duration::from_millis(100);
        safina_executor::block_on(async move {
            sleep_until(deadline).await;
        });
        expect_elapsed(before, 100..200);
    }

    #[test]
    pub fn test_sleep_until_past() {
        start_timer_thread();
        let before = Instant::now();
        let deadline = before - Duration::from_millis(100);
        safina_executor::block_on(async move {
            sleep_until(deadline).await;
        });
        expect_elapsed(before, 0..90);
    }

    /// Spawns several sleeps with different deadlines on a single-threaded executor and
    /// checks that each one finishes within 90 ms after its deadline.
    #[test]
    pub fn test_multi_sleep() {
        let executor = safina_executor::Executor::new(1);
        start_timer_thread();
        let before = Instant::now();
        let mut expected_durations_ms = [200_u64, 100, 0, 400, 500, 300];
        let receiver = {
            let (sender, receiver) = std::sync::mpsc::channel::<u64>();
            for &duration_ms in &expected_durations_ms {
                let sender_clone = sender.clone();
                // `before` and `duration_ms` are `Copy`, so `async move` captures copies.
                executor.spawn(async move {
                    sleep_until(before + Duration::from_millis(duration_ms)).await;
                    let elapsed_ms = u64::try_from(before.elapsed().as_millis()).unwrap();
                    sender_clone.send(elapsed_ms).unwrap();
                });
            }
            // The original `sender` is dropped here, so iterating `receiver` ends once
            // every task has dropped its clone.
            receiver
        };
        let mut actual_durations_ms: Vec<u64> = receiver.into_iter().collect();
        actual_durations_ms.sort_unstable();
        expected_durations_ms.sort_unstable();
        // A task that died before sending would otherwise surface as an index panic.
        assert_eq!(
            actual_durations_ms.len(),
            expected_durations_ms.len(),
            "some sleep tasks did not report back: got {:?}",
            actual_durations_ms
        );
        for (&actual, &expected) in actual_durations_ms.iter().zip(&expected_durations_ms) {
            let range = expected..(expected + 90);
            if !range.contains(&actual) {
                panic!("{:?} ms actual, out of range {:?}", actual, range);
            }
        }
    }

    #[test]
    pub fn should_use_most_recent_waker_passed_to_poll() {
        // "Note that on multiple calls to poll, only the Waker from the Context
        // passed to the most recent call should be scheduled to receive a wakeup."
        // https://doc.rust-lang.org/stable/std/future/trait.Future.html#tymethod.poll
        start_timer_thread();
        let deadline = Instant::now() + Duration::from_millis(100);
        let mut fut = Box::pin(async move {
            sleep_until(deadline).await;
        });
        let waker1_called = Arc::new(AtomicBool::new(false));
        {
            let waker1 = FakeWaker::new(&waker1_called).into_waker();
            let mut cx = Context::from_waker(&waker1);
            assert_eq!(Poll::Pending, fut.as_mut().poll(&mut cx));
        }
        let waker2_called = Arc::new(AtomicBool::new(false));
        {
            let waker2 = FakeWaker::new(&waker2_called).into_waker();
            let mut cx = Context::from_waker(&waker2);
            assert_eq!(Poll::Pending, fut.as_mut().poll(&mut cx));
        }
        std::thread::sleep(Duration::from_millis(200));
        {
            let waker3_called = Arc::new(AtomicBool::new(true /* should never get called */));
            let waker3 = FakeWaker::new(&waker3_called).into_waker();
            let mut cx = Context::from_waker(&waker3);
            assert_eq!(Poll::Ready(()), fut.as_mut().poll(&mut cx));
        }
        assert!(!waker1_called.load(Ordering::Acquire));
        assert!(waker2_called.load(Ordering::Acquire));
    }
}