1use std::future::Future;
2use std::io::Error as IoError;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Instant;
6
7use crate::{ClockId, TimerFd};
8use futures_core::ready;
9use timerfd::{SetTimeFlags, TimerState};
10use tokio::io::{AsyncRead, ReadBuf};
11
12pub struct Delay {
16 timerfd: TimerFd,
17 deadline: Instant,
18 initialized: bool,
19}
20
21impl Delay {
22 pub fn new(deadline: Instant) -> Result<Self, IoError> {
24 let timerfd = TimerFd::new(ClockId::Monotonic)?;
25 Ok(Delay {
26 timerfd,
27 deadline,
28 initialized: false,
29 })
30 }
31
32 pub fn deadline(&self) -> Instant {
34 self.deadline
35 }
36
37 pub fn is_elapsed(&self) -> bool {
39 self.deadline > Instant::now()
40 }
41
42 pub fn reset(&mut self, deadline: Instant) {
44 self.deadline = deadline;
45 self.initialized = false;
46 }
47}
48
49impl Future for Delay {
50 type Output = Result<(), IoError>;
51
52 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
53 if !self.initialized {
54 let now = Instant::now();
55 let duration = if self.deadline > now {
56 self.deadline - now
57 } else {
58 return Poll::Ready(Ok(()));
59 };
60 self.timerfd
61 .set_state(TimerState::Oneshot(duration), SetTimeFlags::Default);
62 self.initialized = true;
63 }
64 let mut buf = [0u8; 8];
65 let mut buf = ReadBuf::new(&mut buf);
66 ready!(Pin::new(&mut self.as_mut().timerfd).poll_read(cx, &mut buf)?);
67 Poll::Ready(Ok(()))
68 }
69}
70
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    #[tokio::test]
    async fn delay_zero_duration() -> Result<(), std::io::Error> {
        let now = Instant::now();
        let delay = Delay::new(Instant::now())?;
        delay.await?;
        let elapsed = now.elapsed();
        println!("{:?}", elapsed);
        assert!(elapsed < Duration::from_millis(1));
        Ok(())
    }

    #[tokio::test]
    async fn delay_works() {
        let now = Instant::now();
        let wait = Duration::from_micros(10);
        let delay = Delay::new(now + wait).unwrap();
        delay.await.unwrap();
        let elapsed = now.elapsed();
        println!("{:?}", elapsed);
        // Without a lower bound this test would also pass if the future resolved
        // immediately without waiting for its deadline.
        assert!(elapsed >= wait);
        assert!(elapsed < Duration::from_millis(1));
    }

    #[tokio::test]
    async fn reset_rearms_pending_delay() {
        let mut delay = Delay::new(Instant::now() + Duration::from_secs(60)).unwrap();

        // Poll once so the timerfd gets armed for the far-away deadline.
        tokio::select! {
            biased;
            _ = &mut delay => panic!("delay with a 60s deadline completed immediately"),
            _ = std::future::ready(()) => {}
        }

        let target = Instant::now() + Duration::from_micros(10);
        delay.reset(target);
        assert_eq!(delay.deadline(), target);

        // Completing at all (instead of hanging for 60s) shows the fd was re-armed.
        delay.await.unwrap();
        assert!(Instant::now() >= target);
    }
}