tokio_tasks/
task.rs

1//! Implement task management for tokio
2use crate::{RunToken, scope_guard::scope_guard};
3use futures_util::{
4    Future, FutureExt,
5    future::{self},
6    pin_mut,
7};
8use log::{debug, error, info};
9use std::{
10    borrow::Cow,
11    sync::{
12        Arc,
13        atomic::{AtomicUsize, Ordering},
14    },
15};
16use std::{collections::HashMap, sync::atomic::AtomicBool};
17use std::{fmt::Display, sync::Mutex};
18use std::{pin::Pin, task::Poll};
19use tokio::{
20    sync::Notify,
21    task::{JoinError, JoinHandle},
22};
23
24#[cfg(feature = "ordered-locks")]
25use ordered_locks::{CleanLockToken, L0, LockToken};
26
/// [HashMap] of all running tasks, keyed by task id.
///
/// Starts out as `None`; the code in this file creates the map lazily through
/// `get_or_insert_default` whenever it takes the lock.
static TASKS: Mutex<Option<HashMap<usize, Arc<dyn TaskBase>>>> = Mutex::new(None);
/// Notified by the background task spawned in [shutdown] once it is done; awaited by [run_tasks]
static SHUTDOWN_NOTIFY: Notify = Notify::const_new();
/// Incremental counter for task ids, advanced by [TaskBuilder::new]
static TASK_ID_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Atomic boolean indicating if we are currently shutting down, set by [shutdown]
static SHUTTING_DOWN: AtomicBool = AtomicBool::new(false);
35
/// Error returned by [cancelable] when the run token was cancelled before the future completed
#[derive(Debug)]
pub struct CancelledError {}

impl std::error::Error for CancelledError {}

impl Display for CancelledError {
    /// Always renders as the fixed text `CancelledError`
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("CancelledError")
    }
}
45
/// A pinned, boxed, [Send] future of T that may borrow data for the lifetime `'a`
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
48
49/// Return result from fut, unless run_token is canceled before fut is done
50pub async fn cancelable<T, F: Future<Output = T>>(
51    run_token: &RunToken,
52    fut: F,
53) -> Result<T, CancelledError> {
54    let c = run_token.cancelled();
55    pin_mut!(fut, c);
56    let f = future::select(c, fut).await;
57    match f {
58        future::Either::Right((v, _)) => Ok(v),
59        future::Either::Left(_) => Err(CancelledError {}),
60    }
61}
62
/// Race `fut` against cancellation of `run_token`, waiting for the cancellation
/// through `cancelled_checked` with the supplied `lock_token`.
///
/// Returns `Ok` with the output of `fut` if it completes first, and
/// `Err(CancelledError)` if `run_token` is cancelled first.
#[cfg(feature = "ordered-locks")]
pub async fn cancelable_checked<T, F: Future<Output = T>>(
    run_token: &RunToken,
    lock_token: LockToken<'_, L0>,
    fut: F,
) -> Result<T, CancelledError> {
    let cancelled = run_token.cancelled_checked(lock_token);
    pin_mut!(fut, cancelled);
    // The cancellation future is the left side of the select, so it is polled first
    match future::select(cancelled, fut).await {
        future::Either::Left(_) => Err(CancelledError {}),
        future::Either::Right((output, _)) => Ok(output),
    }
}
78
/// How a task ended, as reported to [TaskBase::_internal_handle_finished]
#[doc(hidden)]
#[derive(Debug)]
pub enum FinishState<'a> {
    /// The task future returned `Ok`
    Success,
    /// The task future was dropped before it ran to completion
    Drop,
    /// The task was aborted through its join handle by `Task::cancel`
    Abort,
    /// Joining the task failed with the given [JoinError]
    JoinError(JoinError),
    /// The task future returned `Err`; holds a borrow of that error for logging
    Failure(&'a (dyn std::fmt::Debug + Sync + Send)),
}
88
/// Builder to create a new task
pub struct TaskBuilder {
    /// Id of the task to build
    id: usize,
    /// Name of the task to build
    name: Cow<'static, str>,
    /// Run token to use for the task
    run_token: RunToken,
    /// Stop the application if the task fails
    critical: bool,
    /// Stop the application if the task finishes
    main: bool,
    /// Stop the task by dropping the future instead of cancelling the [RunToken]
    abort: bool,
    /// Do not shut down the task when the application is shutting down
    no_shutdown: bool,
    /// Shutdown order of the task; tasks with a lower order are stopped earlier
    shutdown_order: i32,
}
108
109impl TaskBuilder {
110    /// Start the construction of a new task with the given name
111    pub fn new(name: impl Into<Cow<'static, str>>) -> Self {
112        Self {
113            id: TASK_ID_COUNT.fetch_add(1, Ordering::SeqCst),
114            name: name.into(),
115            run_token: Default::default(),
116            critical: false,
117            main: false,
118            abort: false,
119            no_shutdown: false,
120            shutdown_order: 0,
121        }
122    }
123
124    /// Unique id of the task we are creating
125    pub fn id(&self) -> usize {
126        self.id
127    }
128
129    /// Set the run_token for the task. It is sometimes nessesary to
130    /// know the run_token of a task before it is created
131    pub fn set_run_token(self, run_token: RunToken) -> Self {
132        Self { run_token, ..self }
133    }
134
135    /// If the task fails the whole application should be stopped
136    pub fn critical(self) -> Self {
137        Self {
138            critical: true,
139            ..self
140        }
141    }
142
143    /// If the task stops, the whole application should be stopped
144    pub fn main(self) -> Self {
145        Self { main: true, ..self }
146    }
147
148    /// Cancel the task by dropping the future, instead of only setting the cancel token
149    pub fn abort(self) -> Self {
150        Self {
151            abort: true,
152            ..self
153        }
154    }
155
156    /// Keep the task running on shutdown
157    pub fn no_shutdown(self) -> Self {
158        Self {
159            no_shutdown: true,
160            ..self
161        }
162    }
163
164    /// Tasks with a lower shutdown order are stopped earlier on shutdown
165    pub fn shutdown_order(self, shutdown_order: i32) -> Self {
166        Self {
167            shutdown_order,
168            ..self
169        }
170    }
171
172    /// Create the new task
173    pub fn create<
174        T: 'static + Send + Sync,
175        E: std::fmt::Debug + Sync + Send + 'static,
176        Fu: Future<Output = Result<T, E>> + Send + 'static,
177        F: FnOnce(RunToken) -> Fu,
178    >(
179        self,
180        fun: F,
181    ) -> Arc<Task<T, E>> {
182        let fut = fun(self.run_token.clone());
183        let id = self.id;
184        //Lock here so we do not try to remove before inserting
185        let mut tasks = TASKS.lock().unwrap();
186        debug!("Started task {} ({})", self.name, id);
187        let join_handle = tokio::spawn(async move {
188            let g = scope_guard(|| {
189                if let Some(t) = TASKS.lock().unwrap().get_or_insert_default().remove(&id) {
190                    t._internal_handle_finished(FinishState::Drop);
191                }
192            });
193            let r = fut.await;
194            let s = match &r {
195                Ok(_) => FinishState::Success,
196                Err(e) => FinishState::Failure(e),
197            };
198            g.release();
199            if let Some(t) = TASKS.lock().unwrap().get_or_insert_default().remove(&id) {
200                t._internal_handle_finished(s);
201            }
202            r
203        });
204        let task = Arc::new(Task {
205            id: self.id,
206            name: self.name,
207            critical: self.critical,
208            main: self.main,
209            abort: self.abort,
210            no_shutdown: self.no_shutdown,
211            shutdown_order: self.shutdown_order,
212            run_token: self.run_token,
213            start_time: std::time::SystemTime::now()
214                .duration_since(std::time::UNIX_EPOCH)
215                .unwrap()
216                .as_secs_f64(),
217            join_handle: Mutex::new(Some(join_handle)),
218        });
219        tasks.get_or_insert_default().insert(self.id, task.clone());
220        task
221    }
222
223    /// Create the new task also giving it a clean lock token
224    #[cfg(feature = "ordered-locks")]
225    pub fn create_with_lock_token<
226        T: 'static + Send + Sync,
227        E: std::fmt::Debug + Sync + Send + 'static,
228        Fu: Future<Output = Result<T, E>> + Send + 'static,
229        F: FnOnce(RunToken, CleanLockToken) -> Fu,
230    >(
231        self,
232        fun: F,
233    ) -> Arc<Task<T, E>> {
234        // Safety: We do not hold any lock in the task context
235        self.create(|run_token| fun(run_token, unsafe { CleanLockToken::new() }))
236    }
237}
238
/// Base trait for all tasks, that is independent of the return type
pub trait TaskBase: Send + Sync {
    /// Report how the task ended; logs the outcome and may start a shutdown
    #[doc(hidden)]
    fn _internal_handle_finished(&self, state: FinishState);
    /// Return the shutdown order of this task as defined by the [TaskBuilder]
    fn shutdown_order(&self) -> i32;
    /// Return the name of this task as defined by the [TaskBuilder]
    fn name(&self) -> &str;
    /// Return the unique id of this task
    fn id(&self) -> usize;
    /// If true the application will shut down with an error if this task returns
    fn main(&self) -> bool;
    /// If this is true the task will be cancelled by dropping the future instead of signaling the run token
    fn abort(&self) -> bool;
    /// If true the application will shut down with an error if this task returns with an error
    fn critical(&self) -> bool;
    /// Unix timestamp of when the task started
    fn start_time(&self) -> f64;
    /// Cancel the task, return a future that resolves when the task is done
    fn cancel(self: Arc<Self>) -> BoxFuture<'static, ()>;
    /// Get the run token associated with the task
    fn run_token(&self) -> &RunToken;
    /// Do not stop task on shutdown
    fn no_shutdown(&self) -> bool;
}
264
265/// A possible running task, with a return value of `Result<T, E>`
266pub struct Task<T: Send + Sync, E: Sync + Sync> {
267    /// The unique id of the task
268    id: usize,
269    /// The name of the task
270    name: Cow<'static, str>,
271    /// Stop the application if this task fails
272    critical: bool,
273    /// Stop the application if this task finishes
274    main: bool,
275    /// std::mem::drop the future of the task when shutting down, instead of cancelling the [RunToken]
276    abort: bool,
277    /// Do not wait for the task to shut down when shutting down
278    no_shutdown: bool,
279    /// Order in which the task should shut down
280    shutdown_order: i32,
281    /// Run token associated with the task
282    run_token: RunToken,
283    /// Start time of the task as unix time stamp
284    start_time: f64,
285    /// Join handle for the task
286    join_handle: Mutex<Option<JoinHandle<Result<T, E>>>>,
287}
288
289impl<T: Send + Sync + 'static, E: Send + Sync + 'static> TaskBase for Task<T, E> {
290    fn shutdown_order(&self) -> i32 {
291        self.shutdown_order
292    }
293
294    fn name(&self) -> &str {
295        self.name.as_ref()
296    }
297
298    fn id(&self) -> usize {
299        self.id
300    }
301
302    fn _internal_handle_finished(&self, state: FinishState) {
303        match state {
304            FinishState::Success => {
305                if !self.main
306                    || !shutdown(format!(
307                        "Main task {} ({}) finished unexpected",
308                        self.name, self.id
309                    ))
310                {
311                    debug!("Finished task {} ({})", self.name, self.id);
312                }
313            }
314            FinishState::Drop => {
315                if self.main || self.critical {
316                    if shutdown(format!("Critical task {} ({}) dropped", self.name, self.id)) {
317                    } else if !self.abort {
318                        // Task was dropped, but it is not allowed to be dropped
319                        error!("Critical task {} ({}) dropped", self.name, self.id);
320                    } else {
321                        debug!("Critical task {} ({}) dropped", self.name, self.id)
322                    }
323                } else if !self.abort {
324                    // Task was dropped, but it is not allowed to be dropped
325                    error!("Task {} ({}) dropped", self.name, self.id);
326                } else {
327                    debug!("Task {} ({}) dropped", self.name, self.id)
328                }
329            }
330            FinishState::JoinError(e) => {
331                if (!self.main && !self.critical)
332                    || !shutdown(format!(
333                        "Join error in critical task {} ({}): {:?}",
334                        self.name, self.id, e
335                    ))
336                {
337                    error!("Join error in task {} ({}): {:?}", self.name, self.id, e);
338                }
339            }
340            FinishState::Failure(e) => {
341                if (!self.main && !self.critical)
342                    || !shutdown(format!(
343                        "Failure in critical task {} ({}) @ {:?}: {:?}",
344                        self.name,
345                        self.id,
346                        self.run_token().location(),
347                        e
348                    ))
349                {
350                    let location = self.run_token().location();
351                    error!(
352                        "Failure in task {} ({}) @ {:?}: {:?}",
353                        self.name, self.id, location, e
354                    );
355                }
356            }
357            FinishState::Abort => {
358                if !self.main
359                    || !shutdown(format!(
360                        "Main task {} ({}) aborted unexpected",
361                        self.name, self.id
362                    ))
363                {
364                    debug!("Aborted task {} ({})", self.name, self.id);
365                }
366            }
367        }
368    }
369
370    fn cancel(self: Arc<Self>) -> BoxFuture<'static, ()> {
371        Box::pin(self.cancel())
372    }
373
374    fn main(&self) -> bool {
375        self.main
376    }
377
378    fn abort(&self) -> bool {
379        self.abort
380    }
381
382    fn critical(&self) -> bool {
383        self.critical
384    }
385
386    fn start_time(&self) -> f64 {
387        self.start_time
388    }
389
390    fn run_token(&self) -> &RunToken {
391        &self.run_token
392    }
393
394    fn no_shutdown(&self) -> bool {
395        self.no_shutdown
396    }
397}
398
399/// Error return while waiting for a task
400#[derive(Debug)]
401pub enum WaitError<E: Send + Sync> {
402    /// The task has allready been sucessfully awaited
403    HandleUnset(String),
404    /// A join error happened while waiting for the task
405    JoinError(tokio::task::JoinError),
406    /// The task failed with error E
407    TaskFailure(E),
408}
409
410impl<E: std::fmt::Display + Send + Sync> std::fmt::Display for WaitError<E> {
411    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
412        match self {
413            WaitError::HandleUnset(v) => write!(f, "Handle unset: {v}"),
414            WaitError::JoinError(v) => write!(f, "Join Error: {v}"),
415            WaitError::TaskFailure(v) => write!(f, "Task Failure: {v}"),
416        }
417    }
418}
419
420impl<E: std::error::Error + Send + Sync> std::error::Error for WaitError<E> {}
421
422/// Borrow join_handle from task, put it back when dropped
423struct TaskJoinHandleBorrow<'a, T: Send + Sync, E: Send + Sync> {
424    /// The task we have borrowed the join handle from
425    task: &'a Arc<Task<T, E>>,
426    /// The borrowed join handle
427    jh: Option<JoinHandle<Result<T, E>>>,
428}
429
430impl<'a, T: Send + Sync, E: Send + Sync> TaskJoinHandleBorrow<'a, T, E> {
431    /// Borrow the join handle from the task until I am dropped
432    fn new(task: &'a Arc<Task<T, E>>) -> Self {
433        let jh = task.join_handle.lock().unwrap().take();
434        Self { task, jh }
435    }
436}
437
438impl<'a, T: Send + Sync, E: Send + Sync> Drop for TaskJoinHandleBorrow<'a, T, E> {
439    fn drop(&mut self) {
440        *self.task.join_handle.lock().unwrap() = self.jh.take();
441    }
442}
443
444impl<T: Send + Sync, E: Send + Sync> Task<T, E> {
445    /// Cancel the task, either by setting the cancel_token or by aborting it.
446    /// Wait for it to finish
447    /// Note that this function fill fail t
448    pub async fn cancel(self: Arc<Self>) {
449        let mut b = TaskJoinHandleBorrow::new(&self);
450        self.run_token.cancel();
451        if let Some(jh) = &mut b.jh {
452            if self.abort {
453                jh.abort();
454                let _ = jh.await;
455                if let Some(t) = TASKS
456                    .lock()
457                    .unwrap()
458                    .get_or_insert_default()
459                    .remove(&self.id)
460                {
461                    t._internal_handle_finished(FinishState::Abort);
462                }
463            } else if let Err(e) = jh.await {
464                info!("Unable to join task {e:?}");
465                if let Some(t) = TASKS
466                    .lock()
467                    .unwrap()
468                    .get_or_insert_default()
469                    .remove(&self.id)
470                {
471                    t._internal_handle_finished(FinishState::JoinError(e));
472                }
473            }
474        }
475        if !SHUTTING_DOWN.load(Ordering::SeqCst) {
476            info!("  canceled {} ({})", self.name, self.id);
477        }
478        std::mem::forget(b);
479    }
480
481    /// Wait for the task to finish.
482    pub async fn wait(self: Arc<Self>) -> Result<T, WaitError<E>> {
483        let mut b = TaskJoinHandleBorrow::new(&self);
484        let r = match &mut b.jh {
485            None => Err(WaitError::HandleUnset(self.name.to_string())),
486            Some(jh) => match jh.await {
487                Ok(Ok(v)) => Ok(v),
488                Ok(Err(e)) => Err(WaitError::TaskFailure(e)),
489                Err(e) => Err(WaitError::JoinError(e)),
490            },
491        };
492        std::mem::forget(b);
493        r
494    }
495}
496
497/// Future to wait for all futures in the vec to finish
498struct WaitTasks<'a, Sleep, Fut>(Sleep, &'a mut Vec<(String, usize, Fut, RunToken)>);
499impl<'a, Sleep: Unpin, Fut: Unpin> Unpin for WaitTasks<'a, Sleep, Fut> {}
500impl<'a, Sleep: Future + Unpin, Fut: Future + Unpin> Future for WaitTasks<'a, Sleep, Fut> {
501    type Output = bool;
502
503    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<bool> {
504        if self.0.poll_unpin(cx).is_ready() {
505            return Poll::Ready(false);
506        }
507
508        self.1
509            .retain_mut(|(_, _, f, _)| !matches!(f.poll_unpin(cx), Poll::Ready(_)));
510
511        if self.1.is_empty() {
512            Poll::Ready(true)
513        } else {
514            Poll::Pending
515        }
516    }
517}
518
519/// Cancel all tasks in shutdown order
520pub fn shutdown(message: String) -> bool {
521    if SHUTTING_DOWN
522        .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
523        .is_err()
524    {
525        // Already in the process of shutting down
526        return false;
527    }
528    info!("Shutting down: {message}");
529    tokio::spawn(async move {
530        let mut shutdown_tasks: Vec<Arc<dyn TaskBase>> = Vec::new();
531        loop {
532            for (_, task) in TASKS.lock().unwrap().get_or_insert_default().iter() {
533                if task.no_shutdown() {
534                    continue;
535                }
536                if let Some(t) = shutdown_tasks.first() {
537                    if t.shutdown_order() < task.shutdown_order() {
538                        continue;
539                    }
540                    if t.shutdown_order() > task.shutdown_order() {
541                        shutdown_tasks.clear();
542                    }
543                }
544                shutdown_tasks.push(task.clone());
545            }
546            if shutdown_tasks.is_empty() {
547                break;
548            }
549            info!(
550                "shutting down {} tasks with order {}",
551                shutdown_tasks.len(),
552                shutdown_tasks[0].shutdown_order()
553            );
554            let mut stop_futures: Vec<(String, usize, _, RunToken)> = shutdown_tasks
555                .iter()
556                .map(|t| {
557                    (
558                        t.name().to_string(),
559                        t.id(),
560                        t.clone().cancel(),
561                        t.run_token().clone(),
562                    )
563                })
564                .collect();
565            while !WaitTasks(
566                Box::pin(tokio::time::sleep(tokio::time::Duration::from_secs(30))),
567                &mut stop_futures,
568            )
569            .await
570            {
571                info!("still waiting for {} tasks", stop_futures.len(),);
572                for (name, id, _, rt) in &stop_futures {
573                    if let Some((file, line)) = rt.location() {
574                        info!("  {name} ({id}) at {file}:{line}");
575                    } else {
576                        info!("  {name} ({id})");
577                    }
578                }
579            }
580            shutdown_tasks.clear();
581        }
582        info!("shutdown done");
583        SHUTDOWN_NOTIFY.notify_waiters();
584    });
585    true
586}
587
588/// Wait until all tasks are done or shutdown has been called
589pub async fn run_tasks() {
590    SHUTDOWN_NOTIFY.notified().await
591}
592
593/// Return a list of all currently running tasks
594pub fn list_tasks() -> Vec<Arc<dyn TaskBase>> {
595    TASKS
596        .lock()
597        .unwrap()
598        .get_or_insert_default()
599        .values()
600        .cloned()
601        .collect()
602}
603
604/// Try to return a list of all currently running tasks,
605/// if we cannot acquire the lock for the tasks before duration has passed return None
606pub fn try_list_tasks_for(duration: std::time::Duration) -> Option<Vec<Arc<dyn TaskBase>>> {
607    let tries = 50;
608    for _ in 0..tries {
609        if let Ok(mut tasks) = TASKS.try_lock() {
610            return Some(tasks.get_or_insert_default().values().cloned().collect());
611        }
612        std::thread::sleep(duration / tries);
613    }
614    if let Ok(mut tasks) = TASKS.try_lock() {
615        return Some(tasks.get_or_insert_default().values().cloned().collect());
616    }
617    None
618}