1use crate::{RunToken, scope_guard::scope_guard};
3use futures_util::{
4 Future, FutureExt,
5 future::{self},
6 pin_mut,
7};
8use log::{debug, error, info};
9use std::{
10 borrow::Cow,
11 sync::{
12 Arc,
13 atomic::{AtomicUsize, Ordering},
14 },
15};
16use std::{collections::HashMap, sync::atomic::AtomicBool};
17use std::{fmt::Display, sync::Mutex};
18use std::{pin::Pin, task::Poll};
19use tokio::{
20 sync::Notify,
21 task::{JoinError, JoinHandle},
22};
23
24#[cfg(feature = "ordered-locks")]
25use ordered_locks::{CleanLockToken, L0, LockToken};
26
/// Global registry of the tasks that are currently tracked, keyed by task id.
///
/// The inner map is created lazily on first access (`get_or_insert_default`).
static TASKS: Mutex<Option<HashMap<usize, Arc<dyn TaskBase>>>> = Mutex::new(None);
/// Signalled with `notify_waiters` when the shutdown sequence started by
/// `shutdown` has completed; awaited by `run_tasks`.
static SHUTDOWN_NOTIFY: Notify = Notify::const_new();
/// Counter handing out the id of every new `TaskBuilder`.
static TASK_ID_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Flipped to `true` by the first call to `shutdown`; later calls observe it
/// and return `false` without starting a second shutdown.
static SHUTTING_DOWN: AtomicBool = AtomicBool::new(false);
35
/// Error returned by [`cancelable`] when the run token was cancelled before
/// the wrapped future completed.
#[derive(Debug)]
pub struct CancelledError {}

impl Display for CancelledError {
    /// Writes the fixed text `CancelledError`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("CancelledError")
    }
}

impl std::error::Error for CancelledError {}
45
/// A pinned, heap-allocated, `Send` future producing `T` that may borrow data
/// for the lifetime `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
48
49pub async fn cancelable<T, F: Future<Output = T>>(
51 run_token: &RunToken,
52 fut: F,
53) -> Result<T, CancelledError> {
54 let c = run_token.cancelled();
55 pin_mut!(fut, c);
56 let f = future::select(c, fut).await;
57 match f {
58 future::Either::Right((v, _)) => Ok(v),
59 future::Either::Left(_) => Err(CancelledError {}),
60 }
61}
62
/// Like [`cancelable`], but waits for cancellation through
/// `RunToken::cancelled_checked`, which takes `lock_token`.
///
/// Returns `Ok` with the output of `fut` when `fut` finishes first and
/// `Err(CancelledError)` when the cancellation future resolves first.
#[cfg(feature = "ordered-locks")]
pub async fn cancelable_checked<T, F: Future<Output = T>>(
    run_token: &RunToken,
    lock_token: LockToken<'_, L0>,
    fut: F,
) -> Result<T, CancelledError> {
    let cancelled = run_token.cancelled_checked(lock_token);
    // Both futures are pinned on the stack so they can be raced by reference.
    pin_mut!(fut, cancelled);
    match future::select(cancelled, fut).await {
        future::Either::Left(_) => Err(CancelledError {}),
        future::Either::Right((value, _)) => Ok(value),
    }
}
78
/// The way a task ended, as reported to `TaskBase::_internal_handle_finished`.
#[doc(hidden)]
#[derive(Debug)]
pub enum FinishState<'a> {
    /// The task future returned `Ok`.
    Success,
    /// Reported by the scope guard installed in `TaskBuilder::create`, which is
    /// only released after the task future has returned (presumably this means
    /// the future was dropped before completing; confirm against `scope_guard`).
    Drop,
    /// The task was aborted through its join handle by `Task::cancel`.
    Abort,
    /// Joining the task in `Task::cancel` failed with this error.
    JoinError(JoinError),
    /// The task future returned `Err`; holds a reference to the error value.
    Failure(&'a (dyn std::fmt::Debug + Sync + Send)),
}
88
/// Builder that configures a task and spawns it with `create`.
pub struct TaskBuilder {
    /// Unique id taken from `TASK_ID_COUNT` when the builder is constructed.
    id: usize,
    /// Name of the task, used in log messages.
    name: Cow<'static, str>,
    /// Run token handed to the task function and cancelled by `Task::cancel`.
    run_token: RunToken,
    /// When set, a failure, drop or join error of the task triggers `shutdown`.
    critical: bool,
    /// When set, the task ending in any way (success included) triggers `shutdown`.
    main: bool,
    /// When set, `Task::cancel` aborts the task through its join handle instead
    /// of only waiting for it to stop.
    abort: bool,
    /// When set, the task is skipped by `shutdown`.
    no_shutdown: bool,
    /// Order used by `shutdown`; tasks with the lowest value are stopped first.
    shutdown_order: i32,
}
108
109impl TaskBuilder {
110 pub fn new(name: impl Into<Cow<'static, str>>) -> Self {
112 Self {
113 id: TASK_ID_COUNT.fetch_add(1, Ordering::SeqCst),
114 name: name.into(),
115 run_token: Default::default(),
116 critical: false,
117 main: false,
118 abort: false,
119 no_shutdown: false,
120 shutdown_order: 0,
121 }
122 }
123
124 pub fn id(&self) -> usize {
126 self.id
127 }
128
129 pub fn set_run_token(self, run_token: RunToken) -> Self {
132 Self { run_token, ..self }
133 }
134
135 pub fn critical(self) -> Self {
137 Self {
138 critical: true,
139 ..self
140 }
141 }
142
143 pub fn main(self) -> Self {
145 Self { main: true, ..self }
146 }
147
148 pub fn abort(self) -> Self {
150 Self {
151 abort: true,
152 ..self
153 }
154 }
155
156 pub fn no_shutdown(self) -> Self {
158 Self {
159 no_shutdown: true,
160 ..self
161 }
162 }
163
164 pub fn shutdown_order(self, shutdown_order: i32) -> Self {
166 Self {
167 shutdown_order,
168 ..self
169 }
170 }
171
172 pub fn create<
174 T: 'static + Send + Sync,
175 E: std::fmt::Debug + Sync + Send + 'static,
176 Fu: Future<Output = Result<T, E>> + Send + 'static,
177 F: FnOnce(RunToken) -> Fu,
178 >(
179 self,
180 fun: F,
181 ) -> Arc<Task<T, E>> {
182 let fut = fun(self.run_token.clone());
183 let id = self.id;
184 let mut tasks = TASKS.lock().unwrap();
186 debug!("Started task {} ({})", self.name, id);
187 let join_handle = tokio::spawn(async move {
188 let g = scope_guard(|| {
189 if let Some(t) = TASKS.lock().unwrap().get_or_insert_default().remove(&id) {
190 t._internal_handle_finished(FinishState::Drop);
191 }
192 });
193 let r = fut.await;
194 let s = match &r {
195 Ok(_) => FinishState::Success,
196 Err(e) => FinishState::Failure(e),
197 };
198 g.release();
199 if let Some(t) = TASKS.lock().unwrap().get_or_insert_default().remove(&id) {
200 t._internal_handle_finished(s);
201 }
202 r
203 });
204 let task = Arc::new(Task {
205 id: self.id,
206 name: self.name,
207 critical: self.critical,
208 main: self.main,
209 abort: self.abort,
210 no_shutdown: self.no_shutdown,
211 shutdown_order: self.shutdown_order,
212 run_token: self.run_token,
213 start_time: std::time::SystemTime::now()
214 .duration_since(std::time::UNIX_EPOCH)
215 .unwrap()
216 .as_secs_f64(),
217 join_handle: Mutex::new(Some(join_handle)),
218 });
219 tasks.get_or_insert_default().insert(self.id, task.clone());
220 task
221 }
222
223 #[cfg(feature = "ordered-locks")]
225 pub fn create_with_lock_token<
226 T: 'static + Send + Sync,
227 E: std::fmt::Debug + Sync + Send + 'static,
228 Fu: Future<Output = Result<T, E>> + Send + 'static,
229 F: FnOnce(RunToken, CleanLockToken) -> Fu,
230 >(
231 self,
232 fun: F,
233 ) -> Arc<Task<T, E>> {
234 self.create(|run_token| fun(run_token, unsafe { CleanLockToken::new() }))
236 }
237}
238
/// Type-erased view of a task, as stored in the global task registry.
pub trait TaskBase: Send + Sync {
    /// Report how the task ended; not intended to be called by users.
    #[doc(hidden)]
    fn _internal_handle_finished(&self, state: FinishState);
    /// Order used by `shutdown`; tasks with the lowest value are stopped first.
    fn shutdown_order(&self) -> i32;
    /// Name of the task.
    fn name(&self) -> &str;
    /// Unique id of the task.
    fn id(&self) -> usize;
    /// Whether the task is a main task.
    fn main(&self) -> bool;
    /// Whether cancelling the task aborts it through its join handle.
    fn abort(&self) -> bool;
    /// Whether the task is critical.
    fn critical(&self) -> bool;
    /// Creation time of the task in seconds since the Unix epoch.
    fn start_time(&self) -> f64;
    /// Cancel the task; the returned future completes once cancellation is done.
    fn cancel(self: Arc<Self>) -> BoxFuture<'static, ()>;
    /// The run token of the task.
    fn run_token(&self) -> &RunToken;
    /// Whether the task is skipped by `shutdown`.
    fn no_shutdown(&self) -> bool;
}
264
265pub struct Task<T: Send + Sync, E: Sync + Sync> {
267 id: usize,
269 name: Cow<'static, str>,
271 critical: bool,
273 main: bool,
275 abort: bool,
277 no_shutdown: bool,
279 shutdown_order: i32,
281 run_token: RunToken,
283 start_time: f64,
285 join_handle: Mutex<Option<JoinHandle<Result<T, E>>>>,
287}
288
289impl<T: Send + Sync + 'static, E: Send + Sync + 'static> TaskBase for Task<T, E> {
290 fn shutdown_order(&self) -> i32 {
291 self.shutdown_order
292 }
293
294 fn name(&self) -> &str {
295 self.name.as_ref()
296 }
297
298 fn id(&self) -> usize {
299 self.id
300 }
301
302 fn _internal_handle_finished(&self, state: FinishState) {
303 match state {
304 FinishState::Success => {
305 if !self.main
306 || !shutdown(format!(
307 "Main task {} ({}) finished unexpected",
308 self.name, self.id
309 ))
310 {
311 debug!("Finished task {} ({})", self.name, self.id);
312 }
313 }
314 FinishState::Drop => {
315 if self.main || self.critical {
316 if shutdown(format!("Critical task {} ({}) dropped", self.name, self.id)) {
317 } else if !self.abort {
318 error!("Critical task {} ({}) dropped", self.name, self.id);
320 } else {
321 debug!("Critical task {} ({}) dropped", self.name, self.id)
322 }
323 } else if !self.abort {
324 error!("Task {} ({}) dropped", self.name, self.id);
326 } else {
327 debug!("Task {} ({}) dropped", self.name, self.id)
328 }
329 }
330 FinishState::JoinError(e) => {
331 if (!self.main && !self.critical)
332 || !shutdown(format!(
333 "Join error in critical task {} ({}): {:?}",
334 self.name, self.id, e
335 ))
336 {
337 error!("Join error in task {} ({}): {:?}", self.name, self.id, e);
338 }
339 }
340 FinishState::Failure(e) => {
341 if (!self.main && !self.critical)
342 || !shutdown(format!(
343 "Failure in critical task {} ({}) @ {:?}: {:?}",
344 self.name,
345 self.id,
346 self.run_token().location(),
347 e
348 ))
349 {
350 let location = self.run_token().location();
351 error!(
352 "Failure in task {} ({}) @ {:?}: {:?}",
353 self.name, self.id, location, e
354 );
355 }
356 }
357 FinishState::Abort => {
358 if !self.main
359 || !shutdown(format!(
360 "Main task {} ({}) aborted unexpected",
361 self.name, self.id
362 ))
363 {
364 debug!("Aborted task {} ({})", self.name, self.id);
365 }
366 }
367 }
368 }
369
370 fn cancel(self: Arc<Self>) -> BoxFuture<'static, ()> {
371 Box::pin(self.cancel())
372 }
373
374 fn main(&self) -> bool {
375 self.main
376 }
377
378 fn abort(&self) -> bool {
379 self.abort
380 }
381
382 fn critical(&self) -> bool {
383 self.critical
384 }
385
386 fn start_time(&self) -> f64 {
387 self.start_time
388 }
389
390 fn run_token(&self) -> &RunToken {
391 &self.run_token
392 }
393
394 fn no_shutdown(&self) -> bool {
395 self.no_shutdown
396 }
397}
398
399#[derive(Debug)]
401pub enum WaitError<E: Send + Sync> {
402 HandleUnset(String),
404 JoinError(tokio::task::JoinError),
406 TaskFailure(E),
408}
409
410impl<E: std::fmt::Display + Send + Sync> std::fmt::Display for WaitError<E> {
411 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
412 match self {
413 WaitError::HandleUnset(v) => write!(f, "Handle unset: {v}"),
414 WaitError::JoinError(v) => write!(f, "Join Error: {v}"),
415 WaitError::TaskFailure(v) => write!(f, "Task Failure: {v}"),
416 }
417 }
418}
419
420impl<E: std::error::Error + Send + Sync> std::error::Error for WaitError<E> {}
421
422struct TaskJoinHandleBorrow<'a, T: Send + Sync, E: Send + Sync> {
424 task: &'a Arc<Task<T, E>>,
426 jh: Option<JoinHandle<Result<T, E>>>,
428}
429
430impl<'a, T: Send + Sync, E: Send + Sync> TaskJoinHandleBorrow<'a, T, E> {
431 fn new(task: &'a Arc<Task<T, E>>) -> Self {
433 let jh = task.join_handle.lock().unwrap().take();
434 Self { task, jh }
435 }
436}
437
438impl<'a, T: Send + Sync, E: Send + Sync> Drop for TaskJoinHandleBorrow<'a, T, E> {
439 fn drop(&mut self) {
440 *self.task.join_handle.lock().unwrap() = self.jh.take();
441 }
442}
443
444impl<T: Send + Sync, E: Send + Sync> Task<T, E> {
445 pub async fn cancel(self: Arc<Self>) {
449 let mut b = TaskJoinHandleBorrow::new(&self);
450 self.run_token.cancel();
451 if let Some(jh) = &mut b.jh {
452 if self.abort {
453 jh.abort();
454 let _ = jh.await;
455 if let Some(t) = TASKS
456 .lock()
457 .unwrap()
458 .get_or_insert_default()
459 .remove(&self.id)
460 {
461 t._internal_handle_finished(FinishState::Abort);
462 }
463 } else if let Err(e) = jh.await {
464 info!("Unable to join task {e:?}");
465 if let Some(t) = TASKS
466 .lock()
467 .unwrap()
468 .get_or_insert_default()
469 .remove(&self.id)
470 {
471 t._internal_handle_finished(FinishState::JoinError(e));
472 }
473 }
474 }
475 if !SHUTTING_DOWN.load(Ordering::SeqCst) {
476 info!(" canceled {} ({})", self.name, self.id);
477 }
478 std::mem::forget(b);
479 }
480
481 pub async fn wait(self: Arc<Self>) -> Result<T, WaitError<E>> {
483 let mut b = TaskJoinHandleBorrow::new(&self);
484 let r = match &mut b.jh {
485 None => Err(WaitError::HandleUnset(self.name.to_string())),
486 Some(jh) => match jh.await {
487 Ok(Ok(v)) => Ok(v),
488 Ok(Err(e)) => Err(WaitError::TaskFailure(e)),
489 Err(e) => Err(WaitError::JoinError(e)),
490 },
491 };
492 std::mem::forget(b);
493 r
494 }
495}
496
497struct WaitTasks<'a, Sleep, Fut>(Sleep, &'a mut Vec<(String, usize, Fut, RunToken)>);
499impl<'a, Sleep: Unpin, Fut: Unpin> Unpin for WaitTasks<'a, Sleep, Fut> {}
500impl<'a, Sleep: Future + Unpin, Fut: Future + Unpin> Future for WaitTasks<'a, Sleep, Fut> {
501 type Output = bool;
502
503 fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<bool> {
504 if self.0.poll_unpin(cx).is_ready() {
505 return Poll::Ready(false);
506 }
507
508 self.1
509 .retain_mut(|(_, _, f, _)| !matches!(f.poll_unpin(cx), Poll::Ready(_)));
510
511 if self.1.is_empty() {
512 Poll::Ready(true)
513 } else {
514 Poll::Pending
515 }
516 }
517}
518
519pub fn shutdown(message: String) -> bool {
521 if SHUTTING_DOWN
522 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
523 .is_err()
524 {
525 return false;
527 }
528 info!("Shutting down: {message}");
529 tokio::spawn(async move {
530 let mut shutdown_tasks: Vec<Arc<dyn TaskBase>> = Vec::new();
531 loop {
532 for (_, task) in TASKS.lock().unwrap().get_or_insert_default().iter() {
533 if task.no_shutdown() {
534 continue;
535 }
536 if let Some(t) = shutdown_tasks.first() {
537 if t.shutdown_order() < task.shutdown_order() {
538 continue;
539 }
540 if t.shutdown_order() > task.shutdown_order() {
541 shutdown_tasks.clear();
542 }
543 }
544 shutdown_tasks.push(task.clone());
545 }
546 if shutdown_tasks.is_empty() {
547 break;
548 }
549 info!(
550 "shutting down {} tasks with order {}",
551 shutdown_tasks.len(),
552 shutdown_tasks[0].shutdown_order()
553 );
554 let mut stop_futures: Vec<(String, usize, _, RunToken)> = shutdown_tasks
555 .iter()
556 .map(|t| {
557 (
558 t.name().to_string(),
559 t.id(),
560 t.clone().cancel(),
561 t.run_token().clone(),
562 )
563 })
564 .collect();
565 while !WaitTasks(
566 Box::pin(tokio::time::sleep(tokio::time::Duration::from_secs(30))),
567 &mut stop_futures,
568 )
569 .await
570 {
571 info!("still waiting for {} tasks", stop_futures.len(),);
572 for (name, id, _, rt) in &stop_futures {
573 if let Some((file, line)) = rt.location() {
574 info!(" {name} ({id}) at {file}:{line}");
575 } else {
576 info!(" {name} ({id})");
577 }
578 }
579 }
580 shutdown_tasks.clear();
581 }
582 info!("shutdown done");
583 SHUTDOWN_NOTIFY.notify_waiters();
584 });
585 true
586}
587
588pub async fn run_tasks() {
590 SHUTDOWN_NOTIFY.notified().await
591}
592
593pub fn list_tasks() -> Vec<Arc<dyn TaskBase>> {
595 TASKS
596 .lock()
597 .unwrap()
598 .get_or_insert_default()
599 .values()
600 .cloned()
601 .collect()
602}
603
604pub fn try_list_tasks_for(duration: std::time::Duration) -> Option<Vec<Arc<dyn TaskBase>>> {
607 let tries = 50;
608 for _ in 0..tries {
609 if let Ok(mut tasks) = TASKS.try_lock() {
610 return Some(tasks.get_or_insert_default().values().cloned().collect());
611 }
612 std::thread::sleep(duration / tries);
613 }
614 if let Ok(mut tasks) = TASKS.try_lock() {
615 return Some(tasks.get_or_insert_default().values().cloned().collect());
616 }
617 None
618}