sched_callback/
queue.rs

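//! Scheduling queue: keeps pending `Task`s ordered by their next timestamp,
//! waits on the earliest one in a spawned tokio task, and reports lifecycle
//! events (`Add`, `WaitStart`, `WaitEnd`, `Cancel`, `Abort`) over an mpsc
//! channel as `MessageType` messages.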
use std::{collections::BTreeSet, sync::Arc, time::SystemTime};

use async_recursion::async_recursion;
use tokio::{sync::{mpsc, MutexGuard}, task::JoinHandle};

use crate::{message::MessageType, task::Task};

pub(crate) type AsyncRt = tokio::runtime::Handle;

/// A scheduled task together with the handle of the tokio task waiting on it.
struct TaskHandle {
    task: Task,
    handle: JoinHandle<()>,
}
/// Internal queue state, shared behind an `Arc`.
struct _EventQueue {
    /// Next task id to hand out.
    tid: tokio::sync::Mutex<usize>,
    /// Pending tasks, kept sorted so the earliest timestamp is popped first.
    tasks: tokio::sync::Mutex<Vec<Task>>,
    /// The task currently being waited on, if any.
    running: tokio::sync::Mutex<Option<TaskHandle>>,
    /// Ids marked for cancellation; consumed when the task would next run.
    tasks_cancelled: tokio::sync::Mutex<BTreeSet<usize>>,
    /// Channel on which scheduling events are reported.
    tx: mpsc::Sender<MessageType>,
    _rt: AsyncRt,
}
impl _EventQueue {
    fn new(rt: AsyncRt, tx: mpsc::Sender<MessageType>) -> Self {
        Self {
            tid: tokio::sync::Mutex::new(1),
            tasks: tokio::sync::Mutex::new(Vec::new()),
            running: tokio::sync::Mutex::new(None),
            tasks_cancelled: tokio::sync::Mutex::new(BTreeSet::new()),
            tx,
            _rt: rt,
        }
    }
    async fn lock_tasks(&self) -> MutexGuard<'_, Vec<Task>> {
        self.tasks.lock().await
    }
    async fn lock_running(&self) -> MutexGuard<'_, Option<TaskHandle>> {
        self.running.lock().await
    }
    async fn lock_tasks_cancelled(&self) -> MutexGuard<'_, BTreeSet<usize>> {
        self.tasks_cancelled.lock().await
    }
    /// Hand out the next unused task id.
    async fn get_tid(&self) -> usize {
        let mut guard = self.tid.lock().await;
        let tid = *guard;
        *guard += 1;
        tid
    }
}
/// Spawn the given task on the runtime, unless it has been cancelled.
/// Returns `None` if the task was cancelled (a `Cancel` message is emitted instead).
async fn _run_task(sq: Arc<_EventQueue>, task: &Task) -> Option<JoinHandle<()>> {
    let sqc = sq.clone();
    let task = task.clone();
    let task_id = task.id.expect("_run_task: id not assigned to task");

    // If the task was marked as cancelled, consume the mark and report the cancellation.
    let mut tasks_cancelled_guard = sq.lock_tasks_cancelled().await;
    if tasks_cancelled_guard.remove(&task_id) {
        sqc.tx.send(MessageType::Cancel(task_id, SystemTime::now())).await.expect("_run_task: message sending failed");
        return None;
    }

    Some(sq._rt.spawn(async move {
        sqc.tx.send(MessageType::WaitStart(task_id, SystemTime::now())).await.expect("_run_task: message sending failed");
        // Wait until the task's timestamp, then run its callback.
        task.clone().await;
        sqc.tx.send(MessageType::WaitEnd(task_id, SystemTime::now())).await.expect("_run_task: message sending failed");
        // Re-add the task (it may reschedule itself) and start the next pending one.
        _add(sqc.clone(), task).await;
        _next(sqc.clone()).await;
    }))
}

/// Pop the pending task with the earliest timestamp (if any) and start waiting on it.
#[async_recursion]
async fn _next(sq: Arc<_EventQueue>) {
    let mut tasks_guard = sq.lock_tasks().await;
    let mut running_guard = sq.lock_running().await;
    *running_guard = None;

    if let Some(task) = tasks_guard.pop() {
        let handle = _run_task(sq.clone(), &task).await;
        if let Some(handle) = handle {
            *running_guard = Some(TaskHandle { task, handle });
        }
    }
}
/// Add a task to the queue, assigning it an id if it does not have one yet.
/// Returns the task id, or `None` if the task has no further run scheduled.
#[async_recursion]
async fn _add(sq: Arc<_EventQueue>, mut task: Task) -> Option<usize> {
    let mut tasks_guard = sq.lock_tasks().await;
    let mut running_guard = sq.lock_running().await;

    // Assign a task id if it does not have one yet.
    let task_id = match task.id {
        Some(id) => id,
        None => {
            let id = sq.get_tid().await;
            task.id = Some(id);
            id
        }
    };

    // Initialize the task's timestamp.
    task.ready(sq._rt.clone());

    // If the timestamp is None, the task does not need to be scheduled again.
    match task.timestamp {
        Some(_) => {
            sq.tx.send(MessageType::Add(task_id, SystemTime::now())).await.expect("_add: message sending failed");
        }
        None => return None,
    }

    // If the new task fires earlier than the one currently being waited on,
    // abort the waiting task and push it back into the queue; otherwise leave
    // the running task in place.
    if let Some(t) = running_guard.take() {
        match (t.task.timestamp, task.timestamp) {
            (Some(running_timestamp), Some(new_timestamp)) if new_timestamp < running_timestamp => {
                t.handle.abort();
                let aborted_id = t.task.id.expect("_add: running task has no id");
                tasks_guard.push(t.task);
                sq.tx.send(MessageType::Abort(aborted_id, SystemTime::now())).await.expect("_add: message sending failed");
            }
            _ => *running_guard = Some(t),
        }
    }

    // Keep the queue sorted in descending order so `pop` yields the earliest timestamp.
    tasks_guard.push(task);
    tasks_guard.sort_by(|a, b| b.timestamp.unwrap().cmp(&a.timestamp.unwrap()));

    // If nothing is being waited on, start the earliest pending task.
    if running_guard.is_none() {
        if let Some(task) = tasks_guard.pop() {
            let handle = _run_task(sq.clone(), &task).await;
            if let Some(handle) = handle {
                *running_guard = Some(TaskHandle { task, handle });
            }
        }
    }

    Some(task_id)
}

/// Mark a task id as cancelled; returns `true` if it was not already marked.
async fn _cancel(sq: Arc<_EventQueue>, id: usize) -> bool {
    let mut tasks_cancelled_guard = sq.lock_tasks_cancelled().await;
    tasks_cancelled_guard.insert(id)
}

/// Cloneable handle to the scheduling queue.
#[derive(Clone)]
pub struct SchedQueue {
    eq: Arc<_EventQueue>,
}
impl SchedQueue {
    /// Creates a new scheduler together with the receiver for its event messages.
    /// Uses the tokio runtime of the context this function is called from.
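    ///
    /// A minimal usage sketch (not compiled as a doc-test; it assumes the crate
    /// is named `sched_callback` and uses the `Task`/`SchedType` API shown in
    /// the tests below):
    ///
    /// ```ignore
    /// use std::time::{Duration, SystemTime};
    /// use sched_callback::{queue::SchedQueue, task::{SchedType, Task}};
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let (sq, mut rx) = SchedQueue::new();
    ///     // Run a callback roughly 500 ms from now.
    ///     let _ = sq.add(Task::new(
    ///         SchedType::Timestamp(SystemTime::now() + Duration::from_millis(500)),
    ///         Box::new(|| Box::pin(async { println!("fired"); })),
    ///     )).await;
    ///     // Scheduling events (Add, WaitStart, WaitEnd, ...) arrive on `rx`.
    ///     while let Some(_msg) = rx.recv().await {
    ///         // react to scheduling events here
    ///     }
    /// }
    /// ```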
    pub fn new() -> (Self, mpsc::Receiver<MessageType>) {
        let (tx, rx) = mpsc::channel(1000);
        (Self {
            eq: Arc::new(_EventQueue::new(tokio::runtime::Handle::current(), tx)),
        }, rx)
    }
    /// Schedules a task and returns its id, or `None` if the task has nothing left to schedule.
    /// After calling this function, the scheduler automatically starts executing its tasks.
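    ///
    /// A sketch of scheduling a repeating delayed task and cancelling it by the
    /// returned id (continuing from the example on `SchedQueue::new`; not
    /// compiled as a doc-test, and the `SchedType::Delay` signature follows the
    /// tests below):
    ///
    /// ```ignore
    /// // Run the callback after 100 ms, twice.
    /// let id = sq.add(Task::new(
    ///     SchedType::Delay(Duration::from_millis(100), 2),
    ///     Box::new(|| Box::pin(async {})),
    /// )).await;
    /// if let Some(id) = id {
    ///     // Cancel the task; it is dropped the next time it would be scheduled to run.
    ///     sq.cancel(id).await;
    /// }
    /// ```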
    pub async fn add(&self, task: Task) -> Option<usize> {
        _add(self.eq.clone(), task).await
    }

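    /// Marks a task as cancelled by its id. The cancellation takes effect the
    /// next time the task would be scheduled to run; it does not abort a task
    /// that is already being waited on. Returns `true` if the id was not
    /// already marked as cancelled.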
    pub async fn cancel(&self, task_id: usize) -> bool {
        _cancel(self.eq.clone(), task_id).await
    }
}

#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::{Duration, SystemTime}};
    use tokio::sync::Mutex;
    use crate::{queue::SchedQueue, task::{SchedType, Task}};

    #[tokio::test]
    async fn timestamp_test() {
        let (sq, _rx) = SchedQueue::new();

        let reserved_time = SystemTime::now() + Duration::from_millis(500);
        let executed_time = Arc::new(Mutex::new(SystemTime::now()));

        let ex = executed_time.clone();
        let _ = sq.add(Task::new(SchedType::Timestamp(reserved_time), Box::new(move || {
            let ex = ex.clone();
            Box::pin(async move {
                let mut guard = ex.lock().await;
                *guard = SystemTime::now();
            })
        }))).await;
        tokio::time::sleep(Duration::from_secs(1)).await;
        let guard = executed_time.lock().await;

        // Assert that the task ran within 10 ms of its reserved time.
        let diff = guard.duration_since(reserved_time).unwrap();
        assert!(diff < Duration::from_millis(10));
    }
    #[tokio::test]
    async fn delay_test() {
        let (sq, _rx) = SchedQueue::new();
        let order = Arc::new(Mutex::new(Vec::new()));

        for i in 0..10 {
            let order = order.clone();
            let _ = sq.add(Task::new(SchedType::Delay(Duration::from_millis(101 - i), 2), Box::new(move || {
                let order = order.clone();
                Box::pin(async move {
                    let mut guard = order.lock().await;
                    guard.push(i);
                })
            }))).await;
        }
        tokio::time::sleep(Duration::from_secs(1)).await;
        let guard = order.lock().await;
        // Shorter delays finish first, and each task repeats twice.
        let expected = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0];
        assert_eq!(guard.len(), expected.len());
        for (e, r) in expected.iter().zip(&*guard) {
            assert_eq!(e, r);
        }
    }
    #[tokio::test]
    async fn cancel_test() {
        let (sq, _rx) = SchedQueue::new();
        let order = Arc::new(Mutex::new(Vec::new()));

        for i in 0..10 {
            let order = order.clone();
            let sqc = sq.clone();
            let _ = sq.add(Task::new(SchedType::Delay(Duration::from_millis(101 - i), 2), Box::new(move || {
                let order = order.clone();
                let sqc = sqc.clone();
                Box::pin(async move {
                    // Task ids start at 1, so each task cancels its own id on its
                    // first run and its second repetition is dropped.
                    sqc.cancel((i + 1).try_into().unwrap()).await;
                    let mut guard = order.lock().await;
                    guard.push(i);
                })
            }))).await;
        }
        tokio::time::sleep(Duration::from_secs(1)).await;
        let guard = order.lock().await;
        let expected = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0];
        assert_eq!(guard.len(), expected.len());
        for (e, r) in expected.iter().zip(&*guard) {
            assert_eq!(e, r);
        }
    }
}