1use crate::{RunToken, scope_guard::scope_guard};
2use futures_util::{
3    Future, FutureExt,
4    future::{self},
5    pin_mut,
6};
7use log::{debug, error, info};
8use std::{
9    borrow::Cow,
10    sync::{
11        Arc,
12        atomic::{AtomicUsize, Ordering},
13    },
14};
15use std::{collections::HashMap, sync::atomic::AtomicBool};
16use std::{fmt::Display, sync::Mutex};
17use std::{pin::Pin, task::Poll};
18use tokio::{
19    sync::Notify,
20    task::{JoinError, JoinHandle},
21};
22
23#[cfg(feature = "ordered-locks")]
24use ordered_locks::{CleanLockToken, L0, LockToken};
25
/// Global registry of tasks that have been created and have not yet been
/// removed (removal happens when a task finishes), keyed by task id.
/// The `Option` is lazily filled on first use via `get_or_insert_default`.
static TASKS: Mutex<Option<HashMap<usize, Arc<dyn TaskBase>>>> = Mutex::new(None);
/// Notified by the background task spawned in `shutdown` once it has finished;
/// awaited by `run_tasks`.
static SHUTDOWN_NOTIFY: Notify = Notify::const_new();
/// Monotonic counter used to hand out unique task ids in `TaskBuilder::new`.
static TASK_ID_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Set by the first successful call to `shutdown`; later calls see it and bail out.
static SHUTTING_DOWN: AtomicBool = AtomicBool::new(false);
30
/// Error produced by [`cancelable`] (and `cancelable_checked`) when the run
/// token is cancelled before the wrapped future completes.
#[derive(Debug)]
pub struct CancelledError {}

impl Display for CancelledError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("CancelledError")
    }
}

impl std::error::Error for CancelledError {}
40
/// A heap-allocated, pinned, `Send` future that may borrow for `'a` and
/// resolves to `T`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
42
43pub async fn cancelable<T, F: Future<Output = T>>(
45    run_token: &RunToken,
46    fut: F,
47) -> Result<T, CancelledError> {
48    let c = run_token.cancelled();
49    pin_mut!(fut, c);
50    let f = future::select(c, fut).await;
51    match f {
52        future::Either::Right((v, _)) => Ok(v),
53        future::Either::Left(_) => Err(CancelledError {}),
54    }
55}
56
/// Like [`cancelable`], but waits for cancellation through
/// `RunToken::cancelled_checked`, which consumes a level-0 `LockToken`.
///
/// Returns `Ok` with the future's output, or `Err(CancelledError)` when
/// cancellation is observed first (cancellation is the first `select` arm).
#[cfg(feature = "ordered-locks")]
pub async fn cancelable_checked<T, F: Future<Output = T>>(
    run_token: &RunToken,
    lock_token: LockToken<'_, L0>,
    fut: F,
) -> Result<T, CancelledError> {
    let cancelled = run_token.cancelled_checked(lock_token);
    pin_mut!(fut, cancelled);
    match future::select(cancelled, fut).await {
        future::Either::Left(_) => Err(CancelledError {}),
        future::Either::Right((value, _)) => Ok(value),
    }
}
72
/// Why a task left the registry; passed to `TaskBase::_internal_handle_finished`.
#[doc(hidden)]
#[derive(Debug)]
pub enum FinishState<'a> {
    /// The task's future returned `Ok`.
    Success,
    /// The task's future was dropped before completing; reported by the drop
    /// guard installed in `TaskBuilder::create` (presumably abort, panic or
    /// runtime shutdown — confirm against `scope_guard` semantics).
    Drop,
    /// The task was aborted by `Task::cancel` and no earlier path had
    /// already deregistered it.
    Abort,
    /// Awaiting the task's join handle in `Task::cancel` returned an error.
    JoinError(JoinError),
    /// The task's future returned `Err`; borrows that error for logging.
    Failure(&'a (dyn std::fmt::Debug + Sync + Send)),
}
82
/// Builder used to configure and spawn a [`Task`] via [`TaskBuilder::create`].
pub struct TaskBuilder {
    /// Unique id taken from `TASK_ID_COUNT` in [`TaskBuilder::new`].
    id: usize,
    /// Name used in log messages.
    name: Cow<'static, str>,
    /// Token handed to the task's future and cancelled by `Task::cancel`.
    run_token: RunToken,
    /// See [`TaskBuilder::critical`].
    critical: bool,
    /// See [`TaskBuilder::main`].
    main: bool,
    /// See [`TaskBuilder::abort`].
    abort: bool,
    /// See [`TaskBuilder::no_shutdown`].
    no_shutdown: bool,
    /// See [`TaskBuilder::shutdown_order`].
    shutdown_order: i32,
}
94
95impl TaskBuilder {
96    pub fn new(name: impl Into<Cow<'static, str>>) -> Self {
98        Self {
99            id: TASK_ID_COUNT.fetch_add(1, Ordering::SeqCst),
100            name: name.into(),
101            run_token: Default::default(),
102            critical: false,
103            main: false,
104            abort: false,
105            no_shutdown: false,
106            shutdown_order: 0,
107        }
108    }
109
110    pub fn id(&self) -> usize {
112        self.id
113    }
114
115    pub fn set_run_token(self, run_token: RunToken) -> Self {
118        Self { run_token, ..self }
119    }
120
121    pub fn critical(self) -> Self {
123        Self {
124            critical: true,
125            ..self
126        }
127    }
128
129    pub fn main(self) -> Self {
131        Self { main: true, ..self }
132    }
133
134    pub fn abort(self) -> Self {
136        Self {
137            abort: true,
138            ..self
139        }
140    }
141
142    pub fn no_shutdown(self) -> Self {
144        Self {
145            no_shutdown: true,
146            ..self
147        }
148    }
149
150    pub fn shutdown_order(self, shutdown_order: i32) -> Self {
152        Self {
153            shutdown_order,
154            ..self
155        }
156    }
157
158    pub fn create<
160        T: 'static + Send + Sync,
161        E: std::fmt::Debug + Sync + Send + 'static,
162        Fu: Future<Output = Result<T, E>> + Send + 'static,
163        F: FnOnce(RunToken) -> Fu,
164    >(
165        self,
166        fun: F,
167    ) -> Arc<Task<T, E>> {
168        let fut = fun(self.run_token.clone());
169        let id = self.id;
170        let mut tasks = TASKS.lock().unwrap();
172        debug!("Started task {} ({})", self.name, id);
173        let join_handle = tokio::spawn(async move {
174            let g = scope_guard(|| {
175                if let Some(t) = TASKS.lock().unwrap().get_or_insert_default().remove(&id) {
176                    t._internal_handle_finished(FinishState::Drop);
177                }
178            });
179            let r = fut.await;
180            let s = match &r {
181                Ok(_) => FinishState::Success,
182                Err(e) => FinishState::Failure(e),
183            };
184            g.release();
185            if let Some(t) = TASKS.lock().unwrap().get_or_insert_default().remove(&id) {
186                t._internal_handle_finished(s);
187            }
188            r
189        });
190        let task = Arc::new(Task {
191            id: self.id,
192            name: self.name,
193            critical: self.critical,
194            main: self.main,
195            abort: self.abort,
196            no_shutdown: self.no_shutdown,
197            shutdown_order: self.shutdown_order,
198            run_token: self.run_token,
199            start_time: std::time::SystemTime::now()
200                .duration_since(std::time::UNIX_EPOCH)
201                .unwrap()
202                .as_secs_f64(),
203            join_handle: Mutex::new(Some(join_handle)),
204        });
205        tasks.get_or_insert_default().insert(self.id, task.clone());
206        task
207    }
208
209    #[cfg(feature = "ordered-locks")]
211    pub fn create_with_lock_token<
212        T: 'static + Send + Sync,
213        E: std::fmt::Debug + Sync + Send + 'static,
214        Fu: Future<Output = Result<T, E>> + Send + 'static,
215        F: FnOnce(RunToken, CleanLockToken) -> Fu,
216    >(
217        self,
218        fun: F,
219    ) -> Arc<Task<T, E>> {
220        self.create(|run_token| fun(run_token, unsafe { CleanLockToken::new() }))
221    }
222}
223
/// Type-erased interface to a [`Task`], used by the global task registry.
///
/// Lets tasks with different output and error types be listed, inspected and
/// cancelled together (see `list_tasks` and `shutdown`).
pub trait TaskBase: Send + Sync {
    /// Handle the task having finished with `state`: log it and, for main or
    /// critical tasks, possibly start a shutdown. Invoked by whichever code path
    /// removed the task from the registry.
    #[doc(hidden)]
    fn _internal_handle_finished(&self, state: FinishState);
    /// Tasks with a lower value are cancelled earlier during `shutdown`.
    fn shutdown_order(&self) -> i32;
    /// The task's name, as used in log messages.
    fn name(&self) -> &str;
    /// The task's unique id.
    fn id(&self) -> usize;
    /// Whether this is a main task (any exit triggers a shutdown).
    fn main(&self) -> bool;
    /// Whether cancelling the task also aborts the tokio task.
    fn abort(&self) -> bool;
    /// Whether this is a critical task (failure, drop or join error triggers a shutdown).
    fn critical(&self) -> bool;
    /// Creation time in seconds since the Unix epoch.
    fn start_time(&self) -> f64;
    /// Cancel the task and wait until it has stopped.
    fn cancel(self: Arc<Self>) -> BoxFuture<'static, ()>;
    /// The run token handed to the task's future.
    fn run_token(&self) -> &RunToken;
    /// Whether `shutdown` leaves this task alone.
    fn no_shutdown(&self) -> bool;
}
249
250pub struct Task<T: Send + Sync, E: Sync + Sync> {
252    id: usize,
253    name: Cow<'static, str>,
254    critical: bool,
255    main: bool,
256    abort: bool,
257    no_shutdown: bool,
258    shutdown_order: i32,
259    run_token: RunToken,
260    start_time: f64,
261    join_handle: Mutex<Option<JoinHandle<Result<T, E>>>>,
262}
263
264impl<T: Send + Sync + 'static, E: Send + Sync + 'static> TaskBase for Task<T, E> {
265    fn shutdown_order(&self) -> i32 {
266        self.shutdown_order
267    }
268
269    fn name(&self) -> &str {
270        self.name.as_ref()
271    }
272
273    fn id(&self) -> usize {
274        self.id
275    }
276
277    fn _internal_handle_finished(&self, state: FinishState) {
278        match state {
279            FinishState::Success => {
280                if !self.main
281                    || !shutdown(format!(
282                        "Main task {} ({}) finished unexpected",
283                        self.name, self.id
284                    ))
285                {
286                    debug!("Finished task {} ({})", self.name, self.id);
287                }
288            }
289            FinishState::Drop => {
290                if self.main || self.critical {
291                    if shutdown(format!("Critical task {} ({}) dropped", self.name, self.id)) {
292                    } else if !self.abort {
293                        error!("Critical task {} ({}) dropped", self.name, self.id);
295                    } else {
296                        debug!("Critical task {} ({}) dropped", self.name, self.id)
297                    }
298                } else if !self.abort {
299                    error!("Task {} ({}) dropped", self.name, self.id);
301                } else {
302                    debug!("Task {} ({}) dropped", self.name, self.id)
303                }
304            }
305            FinishState::JoinError(e) => {
306                if (!self.main && !self.critical)
307                    || !shutdown(format!(
308                        "Join error in critical task {} ({}): {:?}",
309                        self.name, self.id, e
310                    ))
311                {
312                    error!("Join error in task {} ({}): {:?}", self.name, self.id, e);
313                }
314            }
315            FinishState::Failure(e) => {
316                if (!self.main && !self.critical)
317                    || !shutdown(format!(
318                        "Failure in critical task {} ({}) @ {:?}: {:?}",
319                        self.name,
320                        self.id,
321                        self.run_token().location(),
322                        e
323                    ))
324                {
325                    let location = self.run_token().location();
326                    error!(
327                        "Failure in task {} ({}) @ {:?}: {:?}",
328                        self.name, self.id, location, e
329                    );
330                }
331            }
332            FinishState::Abort => {
333                if !self.main
334                    || !shutdown(format!(
335                        "Main task {} ({}) aborted unexpected",
336                        self.name, self.id
337                    ))
338                {
339                    debug!("Aborted task {} ({})", self.name, self.id);
340                }
341            }
342        }
343    }
344
345    fn cancel(self: Arc<Self>) -> BoxFuture<'static, ()> {
346        Box::pin(self.cancel())
347    }
348
349    fn main(&self) -> bool {
350        self.main
351    }
352
353    fn abort(&self) -> bool {
354        self.abort
355    }
356
357    fn critical(&self) -> bool {
358        self.critical
359    }
360
361    fn start_time(&self) -> f64 {
362        self.start_time
363    }
364
365    fn run_token(&self) -> &RunToken {
366        &self.run_token
367    }
368
369    fn no_shutdown(&self) -> bool {
370        self.no_shutdown
371    }
372}
373
374#[derive(Debug)]
376pub enum WaitError<E: Send + Sync> {
377    HandleUnset(String),
379    JoinError(tokio::task::JoinError),
381    TaskFailure(E),
383}
384
385impl<E: std::fmt::Display + Send + Sync> std::fmt::Display for WaitError<E> {
386    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
387        match self {
388            WaitError::HandleUnset(v) => write!(f, "Handle unset: {v}"),
389            WaitError::JoinError(v) => write!(f, "Join Error: {v}"),
390            WaitError::TaskFailure(v) => write!(f, "Task Failure: {v}"),
391        }
392    }
393}
394
395impl<E: std::error::Error + Send + Sync> std::error::Error for WaitError<E> {}
396
397struct TaskJoinHandleBorrow<'a, T: Send + Sync, E: Send + Sync> {
398    task: &'a Arc<Task<T, E>>,
399    jh: Option<JoinHandle<Result<T, E>>>,
400}
401
402impl<'a, T: Send + Sync, E: Send + Sync> TaskJoinHandleBorrow<'a, T, E> {
403    fn new(task: &'a Arc<Task<T, E>>) -> Self {
404        let jh = task.join_handle.lock().unwrap().take();
405        Self { task, jh }
406    }
407}
408
409impl<'a, T: Send + Sync, E: Send + Sync> Drop for TaskJoinHandleBorrow<'a, T, E> {
410    fn drop(&mut self) {
411        *self.task.join_handle.lock().unwrap() = self.jh.take();
412    }
413}
414
415impl<T: Send + Sync, E: Send + Sync> Task<T, E> {
416    pub async fn cancel(self: Arc<Self>) {
420        let mut b = TaskJoinHandleBorrow::new(&self);
421        self.run_token.cancel();
422        if let Some(jh) = &mut b.jh {
423            if self.abort {
424                jh.abort();
425                let _ = jh.await;
426                if let Some(t) = TASKS
427                    .lock()
428                    .unwrap()
429                    .get_or_insert_default()
430                    .remove(&self.id)
431                {
432                    t._internal_handle_finished(FinishState::Abort);
433                }
434            } else if let Err(e) = jh.await {
435                info!("Unable to join task {e:?}");
436                if let Some(t) = TASKS
437                    .lock()
438                    .unwrap()
439                    .get_or_insert_default()
440                    .remove(&self.id)
441                {
442                    t._internal_handle_finished(FinishState::JoinError(e));
443                }
444            }
445        }
446        if !SHUTTING_DOWN.load(Ordering::SeqCst) {
447            info!("  canceled {} ({})", self.name, self.id);
448        }
449        std::mem::forget(b);
450    }
451
452    pub async fn wait(self: Arc<Self>) -> Result<T, WaitError<E>> {
454        let mut b = TaskJoinHandleBorrow::new(&self);
455        let r = match &mut b.jh {
456            None => Err(WaitError::HandleUnset(self.name.to_string())),
457            Some(jh) => match jh.await {
458                Ok(Ok(v)) => Ok(v),
459                Ok(Err(e)) => Err(WaitError::TaskFailure(e)),
460                Err(e) => Err(WaitError::JoinError(e)),
461            },
462        };
463        std::mem::forget(b);
464        r
465    }
466}
467struct WaitTasks<'a, Sleep, Fut>(Sleep, &'a mut Vec<(String, usize, Fut, RunToken)>);
468impl<'a, Sleep: Unpin, Fut: Unpin> Unpin for WaitTasks<'a, Sleep, Fut> {}
469impl<'a, Sleep: Future + Unpin, Fut: Future + Unpin> Future for WaitTasks<'a, Sleep, Fut> {
470    type Output = bool;
471
472    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<bool> {
473        if self.0.poll_unpin(cx).is_ready() {
474            return Poll::Ready(false);
475        }
476
477        self.1
478            .retain_mut(|(_, _, f, _)| !matches!(f.poll_unpin(cx), Poll::Ready(_)));
479
480        if self.1.is_empty() {
481            Poll::Ready(true)
482        } else {
483            Poll::Pending
484        }
485    }
486}
487
488pub fn shutdown(message: String) -> bool {
490    if SHUTTING_DOWN
491        .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
492        .is_err()
493    {
494        return false;
496    }
497    info!("Shutting down: {message}");
498    tokio::spawn(async move {
499        let mut shutdown_tasks: Vec<Arc<dyn TaskBase>> = Vec::new();
500        loop {
501            for (_, task) in TASKS.lock().unwrap().get_or_insert_default().iter() {
502                if task.no_shutdown() {
503                    continue;
504                }
505                if let Some(t) = shutdown_tasks.first() {
506                    if t.shutdown_order() < task.shutdown_order() {
507                        continue;
508                    }
509                    if t.shutdown_order() > task.shutdown_order() {
510                        shutdown_tasks.clear();
511                    }
512                }
513                shutdown_tasks.push(task.clone());
514            }
515            if shutdown_tasks.is_empty() {
516                break;
517            }
518            info!(
519                "shutting down {} tasks with order {}",
520                shutdown_tasks.len(),
521                shutdown_tasks[0].shutdown_order()
522            );
523            let mut stop_futures: Vec<(String, usize, _, RunToken)> = shutdown_tasks
524                .iter()
525                .map(|t| {
526                    (
527                        t.name().to_string(),
528                        t.id(),
529                        t.clone().cancel(),
530                        t.run_token().clone(),
531                    )
532                })
533                .collect();
534            while !WaitTasks(
535                Box::pin(tokio::time::sleep(tokio::time::Duration::from_secs(30))),
536                &mut stop_futures,
537            )
538            .await
539            {
540                info!("still waiting for {} tasks", stop_futures.len(),);
541                for (name, id, _, rt) in &stop_futures {
542                    if let Some((file, line)) = rt.location() {
543                        info!("  {name} ({id}) at {file}:{line}");
544                    } else {
545                        info!("  {name} ({id})");
546                    }
547                }
548            }
549            shutdown_tasks.clear();
550        }
551        info!("shutdown done");
552        SHUTDOWN_NOTIFY.notify_waiters();
553    });
554    true
555}
556
557pub async fn run_tasks() {
559    SHUTDOWN_NOTIFY.notified().await
560}
561
562pub fn list_tasks() -> Vec<Arc<dyn TaskBase>> {
564    TASKS
565        .lock()
566        .unwrap()
567        .get_or_insert_default()
568        .values()
569        .cloned()
570        .collect()
571}
572
573pub fn try_list_tasks_for(duration: std::time::Duration) -> Option<Vec<Arc<dyn TaskBase>>> {
576    let tries = 50;
577    for _ in 0..tries {
578        if let Ok(mut tasks) = TASKS.try_lock() {
579            return Some(tasks.get_or_insert_default().values().cloned().collect());
580        }
581        std::thread::sleep(duration / tries);
582    }
583    if let Ok(mut tasks) = TASKS.try_lock() {
584        return Some(tasks.get_or_insert_default().values().cloned().collect());
585    }
586    None
587}