sched_callback/
task.rs

1use std::{future::Future, pin::Pin, sync::Arc, task::{Poll, Waker}, time::{Duration, SystemTime}};
2
3use crate::queue::AsyncRt;
4
// Boxed factory invoked each time the task fires: calling it produces a fresh
// pinned, boxed future resolving to `()`. Both the closure and the future it
// returns must be `Send + 'static`.
type Callback = Box<dyn Fn() -> Pin<Box<dyn Future<Output = ()> + Send + 'static>> + Send + 'static>;
// Reference-counted handle to a value guarded by tokio's async mutex.
type ArcAsyncMutex<T> = Arc<tokio::sync::Mutex<T>>;
// Reference-counted handle to a value guarded by the standard blocking mutex.
type ArcSyncMutex<T> = Arc<std::sync::Mutex<T>>;
8
/// Specifies when, and how many times, a task runs its callback.
///
/// - `SchedType::Timestamp(SystemTime)` runs the callback once, at the given
///   wall-clock time.
/// - `SchedType::Delay(Duration, usize)` runs the callback repeatedly: the
///   `Duration` is the interval to wait each time the task is armed, and the
///   `usize` is the number of runs still remaining.
///
/// Two schedules compare equal when they are the same variant with equal
/// payloads.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SchedType {
    /// Fire once at this wall-clock instant.
    Timestamp(SystemTime),
    /// Fire after the given interval, for the given number of remaining runs.
    Delay(Duration, usize),
}
18
/// Task struct that implements Future.
/// Wait until the timestamp and then execute the callback.
/// `ready` must be called before starting polling of this future.
pub struct Task {
    /// Identifier of the task. `new` leaves it as `None`; nothing in this file
    /// assigns it (presumably set by the owning queue -- confirm against caller).
    pub(crate) id: Option<usize>,
    /// Schedule the task was created with. For `SchedType::Delay` the remaining
    /// run count stored here is decremented by `ready`.
    sched_type: SchedType,
    /// Wall-clock time of the next run, computed by `ready`. `None` means there
    /// is nothing to run, and polling completes immediately.
    pub(crate) timestamp: Option<SystemTime>,
    /// Callback to execute; kept behind an `Arc` so clones of the task share it.
    callback: ArcAsyncMutex<Callback>,
    /// Waker slot shared with the spawned timer job. Created the first time
    /// `poll` returns `Pending`; later polls only refresh the stored waker.
    _waker: Option<ArcSyncMutex<Waker>>,
    /// Runtime handle used to spawn the callback and the timer job. Set by
    /// `ready`; while it is `None`, polling completes immediately.
    _rt: Option<AsyncRt>,
}
30impl Task {
31    /// Creates new task with SchedType and callback.
32    pub fn new(sched_type: SchedType, callback: Callback) -> Self {
33        Self {
34            id: None,
35            sched_type,
36            timestamp: None,
37            callback: Arc::new(tokio::sync::Mutex::new(callback)),
38            _waker: None,
39            _rt: None,
40        }
41    }
42    /// Set timestamp of this future, and handle that executes the async job.
43    pub fn ready(&mut self, rt: tokio::runtime::Handle) {
44        // initialize next timestamp using sched type of task
45        match &mut self.sched_type {
46            SchedType::Timestamp(timestamp) => {
47                match self.timestamp {
48                    Some(_) => self.timestamp = None,
49                    None => self.timestamp = Some(*timestamp),
50                }
51            },
52            SchedType::Delay(dur, count) => {
53                match count {
54                    0 => {
55                        self.timestamp = None; 
56                    },
57                    _ => {
58                        self.timestamp = Some(SystemTime::now() + *dur);
59                        *count -= 1;
60                    }
61                }
62            }
63            
64        }
65        self._rt = Some(rt);
66    }
67}
68impl Clone for Task {
69    fn clone(&self) -> Self {
70        Self {
71            id: self.id,
72            sched_type: self.sched_type.clone(),
73            timestamp: self.timestamp,
74            callback: self.callback.clone(),
75            _waker: self._waker.clone(),
76            _rt: self._rt.clone(),
77        }
78    }
79}
80impl Future for Task {
81    type Output = ();
82
83    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
84        let Some(handle) = self._rt.clone() else { return Poll::Ready(()); };
85        let Some(next_timestamp) = self.timestamp else { return Poll::Ready(()); };
86        if SystemTime::now() >= next_timestamp {
87            let callback = self.callback.clone();
88            handle.spawn(async move {
89                let guard = callback.lock().await;
90                guard().await;
91            });
92            return Poll::Ready(())
93        }    
94        if let Some(waker) = &self._waker {
95            let mut waker = waker.lock().unwrap();
96            if !waker.will_wake(cx.waker()) {
97                *waker = cx.waker().clone();
98            }
99        } else {
100            let waker = Arc::new(std::sync::Mutex::new(cx.waker().clone()));
101            self._waker = Some(waker.clone());
102
103            handle.spawn(async move {
104                let current_time = SystemTime::now();
105                if current_time < next_timestamp {
106                    let diff = next_timestamp.duration_since(current_time).unwrap();
107                    tokio::time::sleep(diff).await;
108                }
109
110                let waker = waker.lock().unwrap();
111                waker.wake_by_ref();
112            });
113        }
114        Poll::Pending
115    }
116}
117