tokio_tasks/
run_token.rs

1//! Defines a run token that can be used to cancel async tasks
2use futures_util::Future;
3
4use std::{
5    mem::MaybeUninit,
6    pin::Pin,
7    sync::{Arc, atomic::Ordering},
8    task::{Context, Poll, Waker},
9};
10
11use std::sync::atomic::AtomicPtr;
12#[cfg(feature = "runtoken-id")]
13use std::sync::atomic::AtomicU64;
14
15#[cfg(feature = "ordered-locks")]
16use ordered_locks::{L0, LockToken};
17
/// Source of the ids handed out to new run tokens.
///
/// Every constructor takes the current value with `fetch_add(1, SeqCst)`, so
/// each token gets a distinct id, starting at 0 and increasing in the order the
/// tokens are constructed.
#[cfg(feature = "runtoken-id")]
static IDC: AtomicU64 = AtomicU64::new(0);
21
/// Intrusive, circular, doubly linked list of `T`s.
///
/// The list does not own its nodes: callers hand in raw pointers to
/// [ListNode]s that they keep alive, and do not move, for as long as the node is
/// linked. Invariants maintained by every method:
/// - `first` is null iff the list is empty;
/// - every linked node has non-null `prev`/`next` pointers and an initialized
///   `data` value;
/// - an unlinked node has null `prev`/`next` pointers and no live `data` value.
pub struct IntrusiveList<T> {
    /// Pointer to the first element in the list, the last element can be found
    /// as `first->prev`. Null when the list is empty.
    first: *mut ListNode<T>,
}

impl<T> Default for IntrusiveList<T> {
    /// Create an empty list
    fn default() -> Self {
        Self {
            first: std::ptr::null_mut(),
        }
    }
}

impl<T> IntrusiveList<T> {
    /// Append `node` to the back of the list, storing `v` in it
    ///
    /// Safety:
    /// - node should be a valid pointer.
    /// - node should only be accessed by us
    /// - node should not be in a list
    /// - node should not have value added to it
    ///
    /// # Panics
    /// Panics if `node` is already linked (its `next` pointer is non-null).
    unsafe fn push_back(&mut self, node: *mut ListNode<T>, v: T) {
        // Safety: node is a valid pointer we may access exclusively. Every access
        // goes through raw pointers, so no `&mut` to one node is kept alive while
        // another (possibly the same) node is written.
        unsafe {
            assert!((*node).next.is_null());
            (*node).data.write(v);
            if self.first.is_null() {
                // Single element: the node is its own neighbour on both sides.
                (*node).next = node;
                (*node).prev = node;
                self.first = node;
            } else {
                // first is non-null, so by the list invariants both first and
                // first->prev (the current last node) are valid linked nodes.
                let first = self.first;
                let last = (*first).prev;
                (*node).prev = last;
                (*node).next = first;
                (*last).next = node;
                (*first).prev = node;
            }
        }
    }

    /// Unlink `node` from the list and return the value stored in it
    ///
    /// Safety:
    /// - node should be a valid pointer.
    /// - node should only be accessed by us
    /// - node should be in this list
    ///
    /// The value is moved out of the node, which is left unlinked (null
    /// `prev`/`next`) and may be pushed again later.
    ///
    /// # Panics
    /// Panics if `node` is not linked.
    unsafe fn remove(&mut self, node: *mut ListNode<T>) -> T {
        // Safety: node is a valid linked node (caller contract); by the list
        // invariants it holds an initialized value and its neighbours are valid
        // nodes we may mutate.
        unsafe {
            assert!(!(*node).next.is_null());
            let v = (*node).data.as_ptr().read();
            let next = (*node).next;
            let prev = (*node).prev;
            if next == node {
                // It was the only element.
                self.first = std::ptr::null_mut();
            } else {
                if self.first == node {
                    self.first = next;
                }
                (*next).prev = prev;
                (*prev).next = next;
            }
            (*node).next = std::ptr::null_mut();
            (*node).prev = std::ptr::null_mut();
            v
        }
    }

    /// Remove all entries from this list, passing each value to `f` in list
    /// order (front to back)
    ///
    /// Each node is fully unlinked before `f` sees its value, so the list stays
    /// consistent even if `f` panics: the remaining nodes are still properly
    /// linked and can be drained or removed later.
    fn drain(&mut self, mut f: impl FnMut(T)) {
        while !self.first.is_null() {
            // Safety: first is non-null, so by the list invariants it is a valid
            // node linked into this list that we may mutate.
            let v = unsafe { self.remove(self.first) };
            f(v);
        }
    }

    /// Check if the node is in a list
    ///
    /// Safety:
    /// - node should be a valid pointer
    /// - No one should modify node while we access it
    unsafe fn in_list(&self, node: *mut ListNode<T>) -> bool {
        // Safety: node is a valid pointer; linked nodes have a non-null `next`.
        unsafe { !(*node).next.is_null() }
    }
}

/// Node used in an [IntrusiveList]
pub struct ListNode<T> {
    /// The previous element in the list (null when not linked)
    prev: *mut ListNode<T>,
    /// The next element in the list (null when not linked)
    next: *mut ListNode<T>,
    /// The data contained in this node; initialized exactly while linked
    data: std::mem::MaybeUninit<T>,
    /// Make sure we do not implement Unpin: linked nodes are referenced by raw
    /// pointers from their neighbours, so they must stay in place
    _pin: std::marker::PhantomPinned,
}

impl<T> Default for ListNode<T> {
    /// Create an unlinked node holding no value
    fn default() -> Self {
        Self {
            prev: std::ptr::null_mut(),
            next: std::ptr::null_mut(),
            data: MaybeUninit::uninit(),
            _pin: Default::default(),
        }
    }
}
161
/// The state a [RunToken] is in
enum State {
    /// The [RunToken] is running
    Run,
    /// The [RunToken] has been cancelled; no method of [RunToken] moves a token
    /// out of this state
    Cancel,
    /// The task should be paused until the [RunToken] is resumed or cancelled
    #[cfg(feature = "pause")]
    Pause,
}
172
/// Inner content of a [RunToken] behind a Mutex
///
/// A `Content` only ever lives inside `Inner::content`, so everything here,
/// including the raw node pointers held by the waker lists, is accessed while
/// holding that lock.
struct Content {
    /// The state of the run token
    state: State,
    /// Wakers to wake when cancelling the [RunToken]
    cancel_wakers: IntrusiveList<Waker>,
    /// Wakers to wake when unpausing or cancelling the [RunToken]
    run_wakers: IntrusiveList<Waker>,
    /// User supplied data stored in runtoken
    #[cfg(feature = "runtoken-user-data")]
    user_data: Option<String>,
}
185
// Safety: `Content` is only `!Send` because of the raw node pointers inside its
// intrusive lists; the values stored in the nodes are `Waker`s, which are Send.
// A `Content` is only stored inside `Inner::content`'s Mutex, so those pointers
// are dereferenced exclusively while holding the lock, whichever thread holds it.
unsafe impl Send for Content {}
188
189impl Content {
190    /// Wake waker when the [RunToken] is cancelled
191    ///
192    /// Safety:
193    /// - node must be a valid pointer
194    /// - node must not contain a value if it is not in a list
195    unsafe fn add_cancel_waker(&mut self, node: *mut ListNode<Waker>, waker: &Waker) {
196        // Safety: We can check if the node is in the list since it is a valid pointer
197        // and we own mutation rights
198        let in_list = unsafe { self.cancel_wakers.in_list(node) };
199        if !in_list {
200            // Safety: Node is a valid pointer, we may mutate it,
201            // we have checked that it is not part of a list, and thus
202            // it has no value
203            unsafe { self.cancel_wakers.push_back(node, waker.clone()) }
204        }
205    }
206
207    /// Wake waker when the [RunToken] is unpaused
208    ///
209    /// Safety:
210    /// - node must be a valid pointer
211    /// - node must not contain a value if it is not in a list
212    #[cfg(feature = "pause")]
213    unsafe fn add_run_waker(&mut self, node: *mut ListNode<Waker>, waker: &Waker) {
214        // Safety: We can check if the node is in the list since it is a valid pointer
215        // and we own mutation rights
216        let in_list = unsafe { self.run_wakers.in_list(node) };
217        if !in_list {
218            // Safety: Node is a valid pointer, we may mutate it,
219            // we have checked that it is not part of a list, and thus
220            // it has no value
221            unsafe { self.run_wakers.push_back(node, waker.clone()) }
222        }
223    }
224
225    /// Remove node from the list of nodes to be woken when the run token is
226    /// cancelled
227    ///
228    /// Safety:
229    /// - node must be a valid pointer
230    /// - if the node is in a list, it mut be the cancel_wakers list
231    unsafe fn remove_cancel_waker(&mut self, node: *mut ListNode<Waker>) {
232        // Safety: We can check if the node is in the list since it is a valid pointer
233        // and we own mutation rights
234        let in_list = unsafe { self.cancel_wakers.in_list(node) };
235        if in_list {
236            // Safety: Node is a valid pointer, we may mutate it and it is in the list
237            unsafe { self.cancel_wakers.remove(node) };
238        }
239    }
240
241    /// Remove node from the list of nodes to be woken when the run token is
242    /// unpaused
243    ///
244    /// Safety:
245    /// - node must be a valid pointer
246    /// - if the node is in a list, it mut be the run_wakers list
247    #[cfg(feature = "pause")]
248    unsafe fn remove_run_waker(&mut self, node: *mut ListNode<Waker>) {
249        // Safety: We can check if the node is in the list since it is a valid pointer
250        // and we own mutation rights
251        let in_list = unsafe { self.run_wakers.in_list(node) };
252        if in_list {
253            // Safety: Node is a valid pointer, we may mutate it and it is in the list
254            unsafe { self.run_wakers.remove(node) };
255        }
256    }
257}
258
/// Inner content of a [RunToken] not behind a mutex
struct Inner {
    /// Condition notified on cancel and resume; blocking waiters such as
    /// `wait_paused_check_cancelled_sync` wait on it
    cond: std::sync::Condvar,
    /// Inner content of this [RunToken] that must be accessed exclusively
    content: std::sync::Mutex<Content>,
    /// The unique id of this run token
    #[cfg(feature = "runtoken-id")]
    id: u64,
    /// The location last set on this run token: either null, or a pointer to
    /// the start of a `&'static str` of the form "file:line\0" (NUL terminated)
    location_file_line: AtomicPtr<u8>,
}
/// Similar to a [`tokio_util::sync::CancellationToken`],
/// the RunToken encapsulates the possibility of canceling an async command.
/// However it also allows pausing and resuming the async command,
/// and it is possible to wait in both a blocking fashion and an asynchronous fashion.
///
/// Cloning a RunToken is cheap: all clones share the same state, so cancelling
/// one clone cancels them all.
#[derive(Clone)]
pub struct RunToken(Arc<Inner>);
278
279impl RunToken {
280    /// Construct a new paused run token
281    #[cfg(feature = "pause")]
282    pub fn new_paused() -> Self {
283        Self(Arc::new(Inner {
284            cond: std::sync::Condvar::new(),
285            content: std::sync::Mutex::new(Content {
286                state: State::Pause,
287                cancel_wakers: Default::default(),
288                run_wakers: Default::default(),
289                #[cfg(feature = "runtoken-user-data")]
290                user_data: None,
291            }),
292            location_file_line: Default::default(),
293            #[cfg(feature = "runtoken-id")]
294            id: IDC.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
295        }))
296    }
297
298    /// Construct a new running run token
299    pub fn new() -> Self {
300        Self(Arc::new(Inner {
301            cond: std::sync::Condvar::new(),
302            content: std::sync::Mutex::new(Content {
303                state: State::Run,
304                cancel_wakers: Default::default(),
305                run_wakers: Default::default(),
306                #[cfg(feature = "runtoken-user-data")]
307                user_data: None,
308            }),
309            location_file_line: Default::default(),
310            #[cfg(feature = "runtoken-id")]
311            id: IDC.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
312        }))
313    }
314
315    /// Cancel computation
316    pub fn cancel(&self) {
317        let mut content = self.0.content.lock().unwrap();
318        if matches!(content.state, State::Cancel) {
319            return;
320        }
321        content.state = State::Cancel;
322
323        content.run_wakers.drain(|w| w.wake());
324        content.cancel_wakers.drain(|w| w.wake());
325        self.0.cond.notify_all();
326    }
327
328    /// Pause computation if we are running
329    #[cfg(feature = "pause")]
330    pub fn pause(&self) {
331        let mut content = self.0.content.lock().unwrap();
332        if !matches!(content.state, State::Run) {
333            return;
334        }
335        content.state = State::Pause;
336    }
337
338    /// Resume computation if we are paused
339    #[cfg(feature = "pause")]
340    pub fn resume(&self) {
341        let mut content = self.0.content.lock().unwrap();
342        if !matches!(content.state, State::Pause) {
343            return;
344        }
345        content.state = State::Run;
346        content.run_wakers.drain(|w| w.wake());
347        self.0.cond.notify_all();
348    }
349
350    /// Return true iff we are canceled
351    pub fn is_cancelled(&self) -> bool {
352        matches!(self.0.content.lock().unwrap().state, State::Cancel)
353    }
354
355    /// Return true iff we are paused
356    #[cfg(feature = "pause")]
357    pub fn is_paused(&self) -> bool {
358        matches!(self.0.content.lock().unwrap().state, State::Pause)
359    }
360
361    /// Return true iff we are runnig
362    #[cfg(feature = "pause")]
363    pub fn is_running(&self) -> bool {
364        matches!(self.0.content.lock().unwrap().state, State::Run)
365    }
366
367    /// Block the thread until we are not paused, and then return true if we are canceled or false if we are running
368    #[cfg(feature = "pause")]
369    pub fn wait_paused_check_cancelled_sync(&self) -> bool {
370        let mut content = self.0.content.lock().unwrap();
371        loop {
372            match &content.state {
373                State::Run => return false,
374                State::Cancel => return true,
375                State::Pause => {
376                    content = self.0.cond.wait(content).unwrap();
377                }
378            }
379        }
380    }
381
382    /// Suspend the async coroutine until we are not paused, and then return true if we are canceled or false if we are running
383    #[cfg(feature = "pause")]
384    pub fn wait_paused_check_cancelled(&self) -> WaitForPauseFuture<'_> {
385        WaitForPauseFuture {
386            token: self,
387            waker: Default::default(),
388        }
389    }
390
391    /// Suspend the async coroutine until cancel() is called
392    pub fn cancelled(&self) -> WaitForCancellationFuture<'_> {
393        WaitForCancellationFuture {
394            token: self,
395            waker: Default::default(),
396        }
397    }
398
399    /// Suspend the async coroutine until cancel() is called, checking that we do not hold
400    /// a too deep lock
401    #[cfg(feature = "ordered-locks")]
402    pub fn cancelled_checked(
403        &self,
404        _lock_token: LockToken<'_, L0>,
405    ) -> WaitForCancellationFuture<'_> {
406        WaitForCancellationFuture {
407            token: self,
408            waker: Default::default(),
409        }
410    }
411
412    /// Store a file line location in the run_token
413    /// The string must be on the form "file:line\0"
414    ///
415    /// You probably want to call the [set_location] macro instead
416    #[inline]
417    pub fn set_location_file_line(&self, file_line_str: &'static str) {
418        assert!(file_line_str.ends_with('\0'));
419        self.0
420            .location_file_line
421            .store(file_line_str.as_ptr() as *mut u8, Ordering::Relaxed);
422    }
423
424    /// Retrieve the stored file,live location in the run_token
425    pub fn location(&self) -> Option<(&'static str, u32)> {
426        let location_file_line = self.0.location_file_line.load(Ordering::Relaxed) as *const u8;
427        if location_file_line.is_null() {
428            return None;
429        }
430        let mut len = 0;
431        // Safety: let must be within the string since it is \0 terminated and we
432        // have not yet seen a \0
433        loop {
434            // Safety: let must be within the string since it is \0 terminated and we
435            // have not yet seen a \0
436            let l = unsafe { location_file_line.add(len) };
437            // Safety: let must be within the string since it is \0 terminated and we
438            // have not yet seen a \0
439            let c = unsafe { *l };
440            if c == b'\0' {
441                break;
442            }
443            len += 1;
444        }
445
446        // Safety: location_file_line points to a byte array with at least len bytes
447        let location_file_line = unsafe { std::slice::from_raw_parts(location_file_line, len) };
448
449        // Safety: location_file_line points to a utf-8 string ending with "\0"
450        let location_file_line = unsafe { std::str::from_utf8_unchecked(location_file_line) };
451
452        match location_file_line.rsplit_once(":") {
453            Some((file, line)) => match line.parse() {
454                Ok(v) => Some((file, v)),
455                Err(_) => Some((location_file_line, 0)),
456            },
457            None => Some((location_file_line, 0)),
458        }
459    }
460
461    #[cfg(feature = "runtoken-id")]
462    /// The unique incremental id of this run token
463    #[inline]
464    pub fn id(&self) -> u64 {
465        self.0.id
466    }
467
468    #[cfg(feature = "runtoken-user-data")]
469    /// Store user data in the run token
470    pub fn set_user_data(&self, data: Option<String>) {
471        self.0.content.lock().unwrap().user_data = data;
472    }
473
474    #[cfg(feature = "runtoken-user-data")]
475    /// Get user data stored in run_token
476    pub fn user_data(&self) -> Option<String> {
477        self.0.content.lock().unwrap().user_data.clone()
478    }
479}
480
/// Update the location stored in a run token to the current file:line
///
/// Expands to a call of `set_location_file_line` on `$run_token` with a
/// NUL-terminated "file:line" string built from `file!()` and `line!()` at the
/// invocation site. The stored location can be read back with `location()`.
#[macro_export]
macro_rules! set_location {
    ($run_token: expr) => {
        $run_token.set_location_file_line(concat!(file!(), ":", line!(), "\0"));
    };
}
488
489impl Default for RunToken {
490    fn default() -> Self {
491        Self::new()
492    }
493}
494
495impl core::fmt::Debug for RunToken {
496    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
497        let mut d = f.debug_tuple("RunToken");
498        match self.0.content.lock().unwrap().state {
499            State::Run => d.field(&"Running"),
500            State::Cancel => d.field(&"Canceled"),
501            #[cfg(feature = "pause")]
502            State::Pause => d.field(&"Paused"),
503        };
504        d.finish()
505    }
506}
507
/// Future returned by [RunToken::cancelled]; it resolves once the token has been
/// cancelled
///
/// Note: [std::mem::forget]ting this future may crash your application,
/// a pointer to the content of this future is stored in the RunToken
#[must_use = "futures do nothing unless polled"]
pub struct WaitForCancellationFuture<'a> {
    /// The token to wait for
    token: &'a RunToken,
    /// Entry in the cancel_wakers list of the RunToken; linked by a pending
    /// poll and unlinked again by cancellation or by dropping the future
    waker: ListNode<Waker>,
}
519
520impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> {
521    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
522        f.debug_struct("WaitForCancellationFuture").finish()
523    }
524}
525
526impl<'a> Future for WaitForCancellationFuture<'a> {
527    type Output = ();
528
529    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
530        let mut content = self.token.0.content.lock().unwrap();
531        match content.state {
532            State::Cancel => Poll::Ready(()),
533            State::Run => {
534                // Safety: We do not move the node
535                let node = unsafe { &mut Pin::get_unchecked_mut(self).waker };
536                // Safety: The node is either not in the list and has no value, or it is
537                // in the list and has a value
538                unsafe { content.add_cancel_waker(node, cx.waker()) };
539                Poll::Pending
540            }
541            #[cfg(feature = "pause")]
542            State::Pause => {
543                // Safety: We do not move the node
544                let node = unsafe { &mut Pin::get_unchecked_mut(self).waker };
545                // Safety: The node is either not in the list and has no value, or it is
546                // in the list and has a value
547                unsafe { content.add_cancel_waker(node, cx.waker()) };
548                Poll::Pending
549            }
550        }
551    }
552}
553
554impl<'a> Drop for WaitForCancellationFuture<'a> {
555    fn drop(&mut self) {
556        // Safety: The node is valid
557        unsafe {
558            self.token
559                .0
560                .content
561                .lock()
562                .unwrap()
563                .remove_cancel_waker(&mut self.waker);
564        }
565    }
566}
567
// Safety: The only non-Send parts are the raw pointers inside `waker`. They are
// only read or written while holding the token's content mutex (in `poll`, in
// `Drop`, and by the token itself when it drains its lists), so the future may be
// moved to, polled on and dropped on any thread. While the node is linked the
// future is pinned, so it is not moved then.
unsafe impl<'a> Send for WaitForCancellationFuture<'a> {}
570
#[cfg(feature = "pause")]
/// Future returned by [RunToken::wait_paused_check_cancelled]; it resolves once
/// the token is no longer paused, yielding `true` if it was cancelled and `false`
/// if it is running
///
/// Note: [std::mem::forget]ting this future may crash your application,
/// a pointer to the content of this future is stored in the RunToken
#[must_use = "futures do nothing unless polled"]
pub struct WaitForPauseFuture<'a> {
    /// The run token to wait to unpause
    token: &'a RunToken,
    /// Entry in the run_wakers list of the RunToken
    waker: ListNode<Waker>,
}
583
#[cfg(feature = "pause")]
impl core::fmt::Debug for WaitForPauseFuture<'_> {
    /// Prints only the type name; the future has no fields worth showing.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("WaitForPauseFuture")
    }
}
590
#[cfg(feature = "pause")]
impl<'a> Future for WaitForPauseFuture<'a> {
    type Output = bool;

    /// Ready with `true` when cancelled or `false` when running; while paused
    /// the task's waker is registered in the token's run list.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<bool> {
        let mut content = self.token.0.content.lock().unwrap();
        let cancelled = match content.state {
            State::Run => false,
            State::Cancel => true,
            State::Pause => {
                // Safety: the node is only ever used in place, it is never moved
                // out of the pinned future.
                let this = unsafe { Pin::get_unchecked_mut(self) };
                // Safety: the node is a valid pointer; Drop unlinks it.
                unsafe { content.add_run_waker(&mut this.waker, cx.waker()) };
                return Poll::Pending;
            }
        };
        Poll::Ready(cancelled)
    }
}
610
#[cfg(feature = "pause")]
impl<'a> Drop for WaitForPauseFuture<'a> {
    /// Unlink this future's node from the token so the token never keeps a
    /// pointer into freed memory.
    fn drop(&mut self) {
        let token = self.token;
        let mut content = token.0.content.lock().unwrap();
        // Safety: `self.waker` is a valid node, and `poll` only ever links it
        // into the run_wakers list.
        unsafe { content.remove_run_waker(&mut self.waker) };
    }
}
625
#[cfg(feature = "pause")]
// Safety: The only non-Send parts are the raw pointers inside `waker`. They are
// only read or written while holding the token's content mutex (in `poll`, in
// `Drop`, and by the token itself when it drains its lists), so the future may be
// moved to, polled on and dropped on any thread. While the node is linked the
// future is pinned, so it is not moved then.
unsafe impl<'a> Send for WaitForPauseFuture<'a> {}