1use dashmap::DashMap;
28use parking_lot::Mutex;
29use tokio::sync::watch;
30use tokio::task;
31use tokio::task::AbortHandle;
32
33use std::future::Future;
34use std::sync::atomic::AtomicBool;
35use std::sync::atomic::Ordering;
36use std::sync::LazyLock;
37
38struct RemoveOnDrop {
41 id: task::Id,
42 storage: &'static ActiveTasks,
43}
44impl Drop for RemoveOnDrop {
45 fn drop(&mut self) {
46 self.storage.remove_task(self.id);
47 }
48}
49
50struct TaskKillswitch {
56 activated: AtomicBool,
58 storage: &'static ActiveTasks,
59
60 all_killed: watch::Receiver<()>,
64 signal_killed: Mutex<Option<watch::Sender<()>>>,
69}
70
71impl TaskKillswitch {
72 fn new(storage: &'static ActiveTasks) -> Self {
73 let (signal_killed, all_killed) = watch::channel(());
74 let signal_killed = Mutex::new(Some(signal_killed));
75
76 Self {
77 activated: AtomicBool::new(false),
78 storage,
79 signal_killed,
80 all_killed,
81 }
82 }
83
84 fn with_leaked_storage() -> Self {
89 let storage = Box::leak(Box::new(ActiveTasks::default()));
90 Self::new(storage)
91 }
92
93 fn was_activated(&self) -> bool {
94 self.activated.load(Ordering::Relaxed)
97 }
98
99 fn spawn_task(&self, fut: impl Future<Output = ()> + Send + 'static) {
100 if self.was_activated() {
101 return;
102 }
103
104 let storage = self.storage;
105 let handle = tokio::spawn(async move {
106 let id = task::id();
107 let _guard = RemoveOnDrop { id, storage };
108 fut.await;
109 })
110 .abort_handle();
111
112 let res = self.storage.add_task_if(handle, || !self.was_activated());
113 if let Err(handle) = res {
114 handle.abort();
116 }
117 }
118
119 fn activate(&self) {
120 assert!(
125 !self.activated.swap(true, Ordering::Relaxed),
126 "killswitch can't be used twice"
127 );
128
129 let tasks = self.storage;
130 let signal_killed = self.signal_killed.lock().take();
131 std::thread::spawn(move || {
132 tasks.kill_all();
133 drop(signal_killed);
134 });
135 }
136
137 fn killed(&self) -> impl Future<Output = ()> + Send + 'static {
138 let mut signal = self.all_killed.clone();
139 async move {
140 let _ = signal.changed().await;
141 }
142 }
143}
144
145enum TaskEntry {
146 Handle(AbortHandle),
148 Tombstone,
151}
152
153#[derive(Default)]
154struct ActiveTasks {
155 tasks: DashMap<task::Id, TaskEntry>,
156}
157
158impl ActiveTasks {
159 fn kill_all(&self) {
160 self.tasks.retain(|_, entry| {
161 if let TaskEntry::Handle(task) = entry {
162 task.abort();
163 }
164 false });
166 }
167
168 fn add_task_if(
169 &self, handle: AbortHandle, cond: impl FnOnce() -> bool,
170 ) -> Result<(), AbortHandle> {
171 use dashmap::Entry::*;
172 let id = handle.id();
173
174 match self.tasks.entry(id) {
175 Vacant(e) => {
176 if !cond() {
177 return Err(handle);
178 }
179 e.insert(TaskEntry::Handle(handle));
180 },
181 Occupied(e) if matches!(e.get(), TaskEntry::Tombstone) => {
182 e.remove();
185 },
186 Occupied(_) => panic!("tokio task ID already in use: {id}"),
187 }
188
189 Ok(())
190 }
191
192 fn remove_task(&self, id: task::Id) {
193 use dashmap::Entry::*;
194 match self.tasks.entry(id) {
195 Vacant(e) => {
196 e.insert(TaskEntry::Tombstone);
198 },
199 Occupied(e) if matches!(e.get(), TaskEntry::Tombstone) => {},
200 Occupied(e) => {
201 e.remove();
202 },
203 }
204 }
205}
206
207static TASK_KILLSWITCH: LazyLock<TaskKillswitch> =
209 LazyLock::new(TaskKillswitch::with_leaked_storage);
210
211#[inline]
216pub fn spawn_with_killswitch(fut: impl Future<Output = ()> + Send + 'static) {
217 TASK_KILLSWITCH.spawn_task(fut);
218}
219
220#[deprecated = "activate() was unnecessarily declared async. Use activate_now() instead."]
221pub async fn activate() {
222 TASK_KILLSWITCH.activate()
223}
224
225#[inline]
231pub fn activate_now() {
232 TASK_KILLSWITCH.activate();
233}
234
235#[inline]
242pub fn killed_signal() -> impl Future<Output = ()> + Send + 'static {
243 TASK_KILLSWITCH.killed()
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::oneshot;

    /// Sends on its oneshot channel when dropped. This lets a test observe that a
    /// task's future was dropped, whether it completed, was aborted, or never ran.
    struct TaskAbortSignal(Option<oneshot::Sender<()>>);

    impl TaskAbortSignal {
        fn new() -> (Self, oneshot::Receiver<()>) {
            let (tx, rx) = oneshot::channel();

            (Self(Some(tx)), rx)
        }
    }

    impl Drop for TaskAbortSignal {
        fn drop(&mut self) {
            let _ = self.0.take().unwrap().send(());
        }
    }

    /// Spawns 1000 long-sleeping tasks on `killswitch` and returns one receiver
    /// per task that resolves when that task's future is dropped.
    fn start_test_tasks(
        killswitch: &TaskKillswitch,
    ) -> Vec<oneshot::Receiver<()>> {
        (0..1000)
            .map(|_| {
                let (tx, rx) = TaskAbortSignal::new();

                killswitch.spawn_task(async move {
                    tokio::time::sleep(tokio::time::Duration::from_secs(3600))
                        .await;
                    drop(tx);
                });

                rx
            })
            .collect()
    }

    #[tokio::test]
    async fn activate_killswitch_early() {
        let killswitch = TaskKillswitch::with_leaked_storage();
        let abort_signals = start_test_tasks(&killswitch);

        killswitch.activate();

        tokio::time::timeout(
            Duration::from_secs(1),
            future::join_all(abort_signals),
        )
        .await
        .expect("tasks should be killed within given timeframe");

        // A signal obtained after activation (possibly after the sender was
        // already dropped) must still resolve.
        tokio::time::timeout(Duration::from_secs(1), killswitch.killed())
            .await
            .expect("killed() signal should have resolved");
    }

    #[tokio::test]
    async fn activate_killswitch_with_delay() {
        let killswitch = TaskKillswitch::with_leaked_storage();
        let abort_signals = start_test_tasks(&killswitch);
        let signal_handle = tokio::spawn(killswitch.killed());

        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        assert!(!signal_handle.is_finished());
        killswitch.activate();

        tokio::time::timeout(
            Duration::from_secs(1),
            future::join_all(abort_signals),
        )
        .await
        .expect("tasks should be killed within given timeframe");

        tokio::time::timeout(Duration::from_secs(1), signal_handle)
            .await
            .expect("killed() signal should have resolved")
            .expect("signal task should join successfully");
    }

    #[tokio::test]
    async fn spawn_after_activation_drops_future() {
        let killswitch = TaskKillswitch::with_leaked_storage();
        killswitch.activate();

        let ran = Arc::new(AtomicBool::new(false));
        let ran_in_task = Arc::clone(&ran);
        let (tx, mut rx) = TaskAbortSignal::new();

        killswitch.spawn_task(async move {
            ran_in_task.store(true, Ordering::SeqCst);
            drop(tx);
        });

        // `spawn_task` returns early after activation, dropping the future
        // synchronously, so the drop signal is already available.
        assert!(rx.try_recv().is_ok(), "future should be dropped immediately");

        tokio::task::yield_now().await;
        assert!(
            !ran.load(Ordering::SeqCst),
            "future must not be polled after activation"
        );
    }
}