1use once_cell::sync::Lazy;
28use tokio::sync::mpsc;
29use tokio::sync::watch;
30use tokio::task::JoinHandle;
31
32use std::collections::HashMap;
33use std::future::Future;
34use std::sync::atomic::AtomicU64;
35use std::sync::atomic::Ordering;
36
37enum ActiveTaskOp {
38 Add { id: u64, handle: JoinHandle<()> },
39 Remove { id: u64 },
40}
41
42struct RemoveOnDrop {
45 id: u64,
46 task_tx_weak: mpsc::WeakUnboundedSender<ActiveTaskOp>,
47}
48impl Drop for RemoveOnDrop {
49 fn drop(&mut self) {
50 if let Some(tx) = self.task_tx_weak.upgrade() {
51 let _ = tx.send(ActiveTaskOp::Remove { id: self.id });
52 }
53 }
54}
55
56struct TaskKillswitch {
61 task_tx: parking_lot::RwLock<Option<mpsc::UnboundedSender<ActiveTaskOp>>>,
64 task_counter: AtomicU64,
65 all_killed: watch::Receiver<()>,
66}
67
68impl TaskKillswitch {
69 fn new() -> Self {
70 let (task_tx, task_rx) = mpsc::unbounded_channel();
71 let (signal_killed, all_killed) = watch::channel(());
72
73 let active_tasks = ActiveTasks {
74 task_rx,
75 tasks: Default::default(),
76 signal_killed,
77 };
78 tokio::spawn(active_tasks.collect());
79
80 Self {
81 task_tx: parking_lot::RwLock::new(Some(task_tx)),
82 task_counter: Default::default(),
83 all_killed,
84 }
85 }
86
87 fn spawn_task(&self, fut: impl Future<Output = ()> + Send + 'static) {
88 let Some(task_tx) = self.task_tx.read().as_ref().cloned() else {
92 return;
93 };
94
95 let id = self.task_counter.fetch_add(1, Ordering::SeqCst);
96 let task_tx_weak = task_tx.downgrade();
97
98 let handle = tokio::spawn(async move {
99 let _guard = RemoveOnDrop { task_tx_weak, id };
103 fut.await;
104 });
105
106 let _ = task_tx.send(ActiveTaskOp::Add { id, handle });
107 }
108
109 fn activate(&self) {
110 assert!(
115 self.task_tx.write().take().is_some(),
116 "killswitch can't be used twice"
117 );
118 }
119
120 fn killed(&self) -> impl Future<Output = ()> + Send + 'static {
121 let mut signal = self.all_killed.clone();
122 async move {
123 let _ = signal.changed().await;
124 }
125 }
126}
127
128struct ActiveTasks {
129 task_rx: mpsc::UnboundedReceiver<ActiveTaskOp>,
130 tasks: HashMap<u64, JoinHandle<()>>,
131 signal_killed: watch::Sender<()>,
132}
133
134impl ActiveTasks {
135 async fn collect(mut self) {
136 while let Some(op) = self.task_rx.recv().await {
137 self.handle_task_op(op);
138 }
139
140 for task in self.tasks.into_values() {
141 task.abort();
142 }
143 drop(self.signal_killed);
144 }
145
146 fn handle_task_op(&mut self, op: ActiveTaskOp) {
147 match op {
148 ActiveTaskOp::Add { id, handle } => {
149 self.tasks.insert(id, handle);
150 },
151 ActiveTaskOp::Remove { id } => {
152 self.tasks.remove(&id);
153 },
154 }
155 }
156}
157
/// Process-wide killswitch used by the public functions in this module.
///
/// It is created lazily on first access. `TaskKillswitch::new` calls
/// `tokio::spawn`, which panics outside a Tokio runtime, so that first access
/// must happen inside a runtime. The collector task then runs on that runtime.
static TASK_KILLSWITCH: Lazy<TaskKillswitch> = Lazy::new(TaskKillswitch::new);
160
161#[inline]
166pub fn spawn_with_killswitch(fut: impl Future<Output = ()> + Send + 'static) {
167 TASK_KILLSWITCH.spawn_task(fut);
168}
169
170#[deprecated = "activate() was unnecessarily declared async. Use activate_now() instead."]
171pub async fn activate() {
172 TASK_KILLSWITCH.activate()
173}
174
175#[inline]
181pub fn activate_now() {
182 TASK_KILLSWITCH.activate();
183}
184
185#[inline]
192pub fn killed_signal() -> impl Future<Output = ()> + Send + 'static {
193 TASK_KILLSWITCH.killed()
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future;
    use std::time::Duration;
    use tokio::sync::oneshot;

    /// Fires its oneshot when dropped, letting a test observe that the future
    /// owning it was dropped (for example because its task was aborted).
    struct TaskAbortSignal(Option<oneshot::Sender<()>>);

    impl TaskAbortSignal {
        fn new() -> (Self, oneshot::Receiver<()>) {
            let (sender, receiver) = oneshot::channel();
            (Self(Some(sender)), receiver)
        }
    }

    impl Drop for TaskAbortSignal {
        fn drop(&mut self) {
            let sender = self.0.take().unwrap();
            let _ = sender.send(());
        }
    }

    /// Spawns 1000 hour-long sleeping tasks on `killswitch`. Returns one
    /// receiver per task, which completes when that task's future is dropped.
    fn start_test_tasks(
        killswitch: &TaskKillswitch,
    ) -> Vec<oneshot::Receiver<()>> {
        let mut receivers = Vec::with_capacity(1000);
        for _ in 0..1000 {
            let (signal, receiver) = TaskAbortSignal::new();
            killswitch.spawn_task(async move {
                tokio::time::sleep(Duration::from_secs(3600)).await;
                drop(signal);
            });
            receivers.push(receiver);
        }
        receivers
    }

    #[tokio::test]
    async fn activate_killswitch_early() {
        let killswitch = TaskKillswitch::new();
        let dropped = start_test_tasks(&killswitch);

        killswitch.activate();

        let all_dropped = future::join_all(dropped);
        tokio::time::timeout(Duration::from_secs(1), all_dropped)
            .await
            .expect("tasks should be killed within given timeframe");
    }

    #[tokio::test]
    async fn activate_killswitch_with_delay() {
        let killswitch = TaskKillswitch::new();
        let dropped = start_test_tasks(&killswitch);
        let killed = tokio::spawn(killswitch.killed());

        // The signal must not fire before activation.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert!(!killed.is_finished());

        killswitch.activate();

        let all_dropped = future::join_all(dropped);
        tokio::time::timeout(Duration::from_secs(1), all_dropped)
            .await
            .expect("tasks should be killed within given timeframe");

        tokio::time::timeout(Duration::from_secs(1), killed)
            .await
            .expect("killed() signal should have resolved")
            .expect("signal task should join successfully");
    }
}