1use std::error::Error;
2use std::fmt::{self, Debug, Display};
3use std::future::Future;
4use std::mem;
5use std::pin::Pin;
6use std::sync::atomic::Ordering;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9
10use futures_util::future::FusedFuture;
11use pin_project_lite::pin_project;
12use tokio::sync::futures::Notified;
13
14use crate::tasker::Shared;
15
/// Marker value signalling that the task was stopped.
///
/// It is both the error returned by `Stopper::ok_or_stopped` and the output
/// of awaiting a [`Stopper`]. Being a zero-sized unit struct, it is cheap to
/// copy and compare, so the corresponding traits are derived.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Stopped;
21
22impl Display for Stopped {
23 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24 write!(f, "The task was stopped")
25 }
26}
27
28impl Error for Stopped {
29 fn description(&self) -> &str {
30 "The task was stopped"
31 }
32}
33
pin_project! {
    /// Future that resolves to [`Stopped`] once the shared stop flag is set
    /// or the shared `notify_stop` notification fires.
    pub struct Stopper {
        // Pending wait on `shared.notify_stop`. It is `None` when the stop
        // flag was already set at construction time (see `Stopper::new`).
        // The `'static` lifetime is artificial: this `Notified` really
        // borrows from the `Shared` value owned through `shared` below.
        // This field MUST stay declared before `shared`. Struct fields drop
        // in declaration order, so the borrow is released before this
        // stopper's `Arc` clone is dropped.
        #[pin] notified: Option<Pin<Box<Notified<'static>>>>,
        // Shared tasker state. Holding an `Arc` clone keeps the
        // `notify_stop` borrowed by `notified` alive.
        shared: Pin<Arc<Shared>>,
    }
}
49
impl Stopper {
    /// Creates a stopper bound to `shared`.
    ///
    /// If `shared.stopped` is already set, no notification is registered and
    /// the returned future is ready on its first poll. Otherwise a
    /// `notify_stop` notification is registered up front.
    pub(crate) fn new(shared: &Pin<Arc<Shared>>) -> Self {
        let notified = if shared.stopped.load(Ordering::SeqCst) {
            None
        } else {
            let notified = shared.notify_stop.notified();
            // SAFETY: `notified` borrows `shared.notify_stop`, which lives in
            // the `Arc` allocation. `Self` keeps that allocation alive through
            // its own clone of `shared` (stored below). Arc contents are never
            // moved. The `notified` field is declared before `shared` in the
            // struct, so it is dropped before that clone. The field is
            // private, so the falsely-`'static` borrow cannot escape `Self`.
            let notified: Notified<'static> = unsafe { mem::transmute(notified) };
            // NOTE(review): a stop landing between the flag load above and
            // `notified()` is still caught, because `poll` re-checks the flag.
            // That relies on the stopping side setting `stopped` before it
            // notifies; confirm in `tasker`.
            Some(Box::pin(notified))
        };

        Self {
            shared: shared.clone(),
            notified,
        }
    }

    /// Returns `true` if the shared stop flag has been set.
    ///
    /// This is a non-blocking atomic load (`SeqCst`); it never waits.
    pub fn is_stopped(&self) -> bool {
        self.shared.stopped.load(Ordering::SeqCst)
    }

    /// Passes `value` through as `Ok(value)` unless the task has been stopped.
    ///
    /// # Errors
    ///
    /// Returns `Err(Stopped)` if the stop flag is set. In that case `value`
    /// is dropped.
    pub fn ok_or_stopped<T>(&self, value: T) -> Result<T, Stopped> {
        if self.is_stopped() {
            Err(Stopped)
        } else {
            Ok(value)
        }
    }
}
83
84impl Future for Stopper {
85 type Output = Stopped;
86
87 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
88 if self.shared.stopped.load(Ordering::SeqCst) {
89 return Poll::Ready(Stopped);
90 }
91
92 let this = self.project();
93 match this.notified.as_pin_mut() {
94 Some(notified) => notified.poll(cx).map(|_| Stopped),
95 None => Poll::Ready(Stopped),
96 }
97 }
98}
99
100impl FusedFuture for Stopper {
101 fn is_terminated(&self) -> bool {
102 self.is_stopped()
103 }
104}
105
106impl Clone for Stopper {
107 fn clone(&self) -> Self {
108 Self::new(&self.shared)
109 }
110}
111
112impl fmt::Debug for Stopper {
113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114 f.debug_struct("Stopper")
115 .field("tasker", &self.shared.ptr())
116 .field("stopped", &self.is_stopped())
117 .finish()
118 }
119}