1use std::{collections::BTreeSet, sync::Arc, time::SystemTime};
2
3use async_recursion::async_recursion;
4use tokio::{sync::{mpsc, MutexGuard}, task::JoinHandle};
5
6use crate::{message::MessageType, task::Task};
7
8pub(crate) type AsyncRt = tokio::runtime::Handle;
9
10struct TaskHandle {
11 task: Task,
12 handle: JoinHandle<()>,
13}
14struct _EventQueue {
15 tid: tokio::sync::Mutex<usize>,
16 tasks: tokio::sync::Mutex<Vec<Task>>,
17 running: tokio::sync::Mutex<Option<TaskHandle>>,
18 tasks_cancelled: tokio::sync::Mutex<BTreeSet<usize>>,
19 tx: mpsc::Sender<MessageType>,
20 _rt: AsyncRt,
21}
22impl _EventQueue {
23 fn new(rt: AsyncRt, tx: mpsc::Sender<MessageType>) -> Self {
24 Self {
25 tid: tokio::sync::Mutex::new(1),
26 tasks: tokio::sync::Mutex::new(Vec::new()),
27 running: tokio::sync::Mutex::new(None),
28 tasks_cancelled: tokio::sync::Mutex::new(BTreeSet::new()),
29 tx,
30 _rt: rt,
31 }
32 }
33 async fn lock_tasks(&self) -> MutexGuard<Vec<Task>> {
34 self.tasks.lock().await
35 }
36 async fn lock_running(&self) -> MutexGuard<Option<TaskHandle>> {
37 self.running.lock().await
38 }
39 async fn lock_tasks_cancelled(&self) -> MutexGuard<BTreeSet<usize>>{
40 self.tasks_cancelled.lock().await
41 }
42 async fn get_tid(&self) -> usize {
43 let mut guard = self.tid.lock().await;
44 let tid = *guard;
45 *guard += 1;
46 tid
47 }
48}
49async fn _run_task(sq: Arc<_EventQueue>, task: &Task) -> Option<JoinHandle<()>> {
50 let sqc = sq.clone();
51 let task = task.clone();
52 let task_id = task.id.expect("_run_task id not assigned to task");
53 let mut tasks_cancelled_guard = sq.lock_tasks_cancelled().await;
55 if tasks_cancelled_guard.remove(&task.id.unwrap()) {
56 sqc.tx.send(MessageType::Cancel(task_id, SystemTime::now())).await.expect("_run_task message sending failed");
57 return None;
58 }
59
60 Some(sq._rt.spawn(async move {
61 sqc.tx.send(MessageType::WaitStart(task_id, SystemTime::now())).await.expect("_run_task message sending failed");
62 task.clone().await;
63 sqc.tx.send(MessageType::WaitEnd(task_id, SystemTime::now())).await.expect("_run_task message sending failed");
64 _add(sqc.clone(), task).await;
65 _next(sqc.clone()).await;
66 }))
67}
68
69#[async_recursion]
70async fn _next(sq: Arc<_EventQueue>) {
71 let mut tasks_guard = sq.lock_tasks().await;
72 let mut running_guard = sq.lock_running().await;
73 *running_guard = None;
74
75 if let Some(task) = tasks_guard.pop() {
76 let handle = _run_task(sq.clone(), &task).await;
77 if let Some(handle) = handle {
78 *running_guard = Some(TaskHandle {
79 task,
80 handle
81 });
82 }
83 }
84}
85#[async_recursion]
86async fn _add(sq: Arc<_EventQueue>, mut task: Task) -> Option<usize> {
87 let mut tasks_guard = sq.tasks.lock().await;
88 let mut running_guard = sq.running.lock().await;
89
90 let task_id: usize;
92 match task.id {
93 Some(id) => task_id = id,
94 None => {
95 task_id = sq.get_tid().await;
96 task.id = Some(task_id);
97 }
98
99 }
100
101 task.ready(sq._rt.clone());
103
104 match task.timestamp {
106 Some(_) => {
107 sq.tx.send(MessageType::Add(task_id, SystemTime::now())).await.expect("_add: message sending failed");
108 },
109 None => {
110 return None;
111 }
112 }
113
114 let taskhandle = running_guard.take();
115 if let Some(t) = taskhandle {
116 let Some(running_timestamp) = t.task.timestamp else { return None; };
117 let Some(new_timestamp) = task.timestamp else { return None; };
118
119 if new_timestamp < running_timestamp {
121 t.handle.abort();
122 tasks_guard.push(t.task);
123 sq.tx.send(MessageType::Abort(task_id, SystemTime::now())).await.expect("_add: message sending failed");
124 }
125 else {
126 *running_guard = Some(t);
127 }
128 }
129
130 tasks_guard.push(task);
131 tasks_guard.sort_by(|a, b| {
132 b.timestamp.unwrap().cmp(&a.timestamp.unwrap())
133 });
134
135 if running_guard.is_none() {
136 if let Some(task) = tasks_guard.pop() {
137 let handle = _run_task(sq.clone(), &task).await;
138 if let Some(handle) = handle {
139 *running_guard = Some(TaskHandle {
140 task,
141 handle
142 });
143 }
144 }
145 }
146
147 return Some(task_id);
148}
149
150async fn _cancel(sq: Arc<_EventQueue>, id: usize) -> bool {
151 let mut tasks_cancelled_guard = sq.lock_tasks_cancelled().await;
152 tasks_cancelled_guard.insert(id)
153}
154
155#[derive(Clone)]
156pub struct SchedQueue {
157 eq: Arc<_EventQueue>,
158}
159impl SchedQueue {
160 pub fn new() -> (Self, mpsc::Receiver<MessageType>) {
163 let (tx, rx) = mpsc::channel(1000);
164 (Self {
165 eq: Arc::new(_EventQueue::new(tokio::runtime::Handle::current(), tx)),
166 }, rx)
167 }
168 pub async fn add(&self, task: Task) -> Option<usize> {
171 _add(self.eq.clone(), task).await
172 }
173
174 pub async fn cancel(&self, task_id: usize) -> bool {
175 _cancel(self.eq.clone(), task_id).await
176 }
177}
178
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::{Duration, SystemTime}};
    use tokio::sync::Mutex;
    use crate::{queue::SchedQueue, task::{SchedType, Task}};

    /// A task scheduled for an absolute point in time must run at that time or less than
    /// 10 ms after it.
    #[tokio::test]
    async fn timestamp_test() {
        let (sq, _rx) = SchedQueue::new();

        let reserved_time = SystemTime::now() + Duration::from_millis(500);
        let executed_time = Arc::new(Mutex::new(SystemTime::now()));

        let recorder = executed_time.clone();
        let _ = sq.add(Task::new(SchedType::Timestamp(reserved_time), Box::new(move || {
            let recorder = recorder.clone();
            Box::pin(async move {
                let mut slot = recorder.lock().await;
                *slot = SystemTime::now();
            })
        }))).await;

        // Give the task ample time to fire before inspecting the recorded instant.
        tokio::time::sleep(Duration::from_secs(1)).await;
        let recorded = executed_time.lock().await;

        let lateness = recorded.duration_since(reserved_time).unwrap();
        assert!(lateness < Duration::from_millis(10));
    }

    /// Ten tasks with delays of 101, 100, …, 92 ms and a count of 2 must each run twice,
    /// shortest delay first in both rounds.
    #[tokio::test]
    async fn delay_test() {
        let (sq, _rx) = SchedQueue::new();
        let order = Arc::new(Mutex::new(Vec::new()));

        for i in 0..10 {
            let order = order.clone();
            let _ = sq.add(Task::new(SchedType::Delay(Duration::from_millis(101 - i), 2), Box::new(move || {
                let order = order.clone();
                Box::pin(async move {
                    let mut log = order.lock().await;
                    log.push(i);
                })
            }))).await;
        }

        tokio::time::sleep(Duration::from_secs(1)).await;
        let observed = order.lock().await;

        // 9, 8, …, 0 followed by 9, 8, …, 0 again.
        let expected: Vec<u64> = (0..10).rev().chain((0..10).rev()).collect();
        assert_eq!(*observed, expected);
    }

    /// Same setup as `delay_test`, but every task requests its own cancellation
    /// (task `i` has id `i + 1`) during its first run, so each task runs exactly once.
    #[tokio::test]
    async fn cancel_test() {
        let (sq, _rx) = SchedQueue::new();
        let order = Arc::new(Mutex::new(Vec::new()));

        for i in 0..10 {
            let order = order.clone();
            let sqc = sq.clone();
            let _ = sq.add(Task::new(SchedType::Delay(Duration::from_millis(101 - i), 2), Box::new(move || {
                let order = order.clone();
                let sqc = sqc.clone();
                Box::pin(async move {
                    sqc.cancel(usize::try_from(i + 1).unwrap()).await;
                    let mut log = order.lock().await;
                    log.push(i);
                })
            }))).await;
        }

        tokio::time::sleep(Duration::from_secs(1)).await;
        let observed = order.lock().await;

        // 9, 8, …, 0 — a single round only.
        let expected: Vec<u64> = (0..10).rev().collect();
        assert_eq!(*observed, expected);
    }
}