task_killswitch/
lib.rs

1// Copyright (C) 2025, Cloudflare, Inc.
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without
5// modification, are permitted provided that the following conditions are
6// met:
7//
8//     * Redistributions of source code must retain the above copyright notice,
9//       this list of conditions and the following disclaimer.
10//
11//     * Redistributions in binary form must reproduce the above copyright
12//       notice, this list of conditions and the following disclaimer in the
13//       documentation and/or other materials provided with the distribution.
14//
15// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
16// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
17// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
19// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27use dashmap::DashMap;
28use parking_lot::Mutex;
29use tokio::sync::watch;
30use tokio::task;
31use tokio::task::AbortHandle;
32
33use std::future::Future;
34use std::sync::atomic::AtomicBool;
35use std::sync::atomic::Ordering;
36use std::sync::LazyLock;
37
38/// Drop guard for task removal. If a task panics, this makes sure
39/// it is removed from [`ActiveTasks`] properly.
40struct RemoveOnDrop {
41    id: task::Id,
42    storage: &'static ActiveTasks,
43}
44impl Drop for RemoveOnDrop {
45    fn drop(&mut self) {
46        self.storage.remove_task(self.id);
47    }
48}
49
50/// A task killswitch that allows aborting all the tasks spawned with it at
51/// once. The implementation strives to minimize in-band locking. Spawning a
52/// future requires a single sharded lock from an internal [`DashMap`].
53/// Conflicts are expected to be very rare (dashmap defaults to `4 * nproc`
54/// shards, while each thread can only spawn one task at a time.)
55struct TaskKillswitch {
56    // Invariant: If `activated` is true, we don't add new tasks anymore.
57    activated: AtomicBool,
58    storage: &'static ActiveTasks,
59
60    /// Watcher that is triggered after all kill signals have been sent (by
61    /// dropping `signal_killed`.) Currently-running tasks are killed after
62    /// their next yield, which may be after this triggers.
63    all_killed: watch::Receiver<()>,
64    // NOTE: All we want here is to take ownership of `signal_killed` when
65    // activating the killswitch. That code path only runs once per instance, but
66    // requires interior mutability. Using `Mutex` is easier than bothering with
67    // an `UnsafeCell`. The mutex is guaranteed to be unlocked.
68    signal_killed: Mutex<Option<watch::Sender<()>>>,
69}
70
71impl TaskKillswitch {
72    fn new(storage: &'static ActiveTasks) -> Self {
73        let (signal_killed, all_killed) = watch::channel(());
74        let signal_killed = Mutex::new(Some(signal_killed));
75
76        Self {
77            activated: AtomicBool::new(false),
78            storage,
79            signal_killed,
80            all_killed,
81        }
82    }
83
84    /// Creates a killswitch by allocating and leaking the task storage.
85    ///
86    /// **NOTE:** This is intended for use in `static`s and tests. It should not
87    /// be exposed publicly!
88    fn with_leaked_storage() -> Self {
89        let storage = Box::leak(Box::new(ActiveTasks::default()));
90        Self::new(storage)
91    }
92
93    fn was_activated(&self) -> bool {
94        // All synchronization is done using locks,
95        // so we can use relaxed for our atomics.
96        self.activated.load(Ordering::Relaxed)
97    }
98
99    fn spawn_task(&self, fut: impl Future<Output = ()> + Send + 'static) {
100        if self.was_activated() {
101            return;
102        }
103
104        let storage = self.storage;
105        let handle = tokio::spawn(async move {
106            let id = task::id();
107            let _guard = RemoveOnDrop { id, storage };
108            fut.await;
109        })
110        .abort_handle();
111
112        let res = self.storage.add_task_if(handle, || !self.was_activated());
113        if let Err(handle) = res {
114            // Killswitch was activated by the time we got a lock on the map shard
115            handle.abort();
116        }
117    }
118
119    fn activate(&self) {
120        // We check `activated` after locking the map shard and before inserting
121        // an element. This ensures in-progress spawns either complete before
122        // `tasks.kill_all()` obtains the lock for that shard, or they abort
123        // afterwards.
124        assert!(
125            !self.activated.swap(true, Ordering::Relaxed),
126            "killswitch can't be used twice"
127        );
128
129        let tasks = self.storage;
130        let signal_killed = self.signal_killed.lock().take();
131        std::thread::spawn(move || {
132            tasks.kill_all();
133            drop(signal_killed);
134        });
135    }
136
137    fn killed(&self) -> impl Future<Output = ()> + Send + 'static {
138        let mut signal = self.all_killed.clone();
139        async move {
140            let _ = signal.changed().await;
141        }
142    }
143}
144
145enum TaskEntry {
146    /// Task was added and not yet removed.
147    Handle(AbortHandle),
148    /// Task was removed before it was added. This can happen if a spawned
149    /// future completes before the spawning thread can add it to the map.
150    Tombstone,
151}
152
153#[derive(Default)]
154struct ActiveTasks {
155    tasks: DashMap<task::Id, TaskEntry>,
156}
157
158impl ActiveTasks {
159    fn kill_all(&self) {
160        self.tasks.retain(|_, entry| {
161            if let TaskEntry::Handle(task) = entry {
162                task.abort();
163            }
164            false // remove all elements
165        });
166    }
167
168    fn add_task_if(
169        &self, handle: AbortHandle, cond: impl FnOnce() -> bool,
170    ) -> Result<(), AbortHandle> {
171        use dashmap::Entry::*;
172        let id = handle.id();
173
174        match self.tasks.entry(id) {
175            Vacant(e) => {
176                if !cond() {
177                    return Err(handle);
178                }
179                e.insert(TaskEntry::Handle(handle));
180            },
181            Occupied(e) if matches!(e.get(), TaskEntry::Tombstone) => {
182                // Task was removed before it was added. Clear the map entry and
183                // drop the handle.
184                e.remove();
185            },
186            Occupied(_) => panic!("tokio task ID already in use: {id}"),
187        }
188
189        Ok(())
190    }
191
192    fn remove_task(&self, id: task::Id) {
193        use dashmap::Entry::*;
194        match self.tasks.entry(id) {
195            Vacant(e) => {
196                // Task was not added yet, set a tombstone instead.
197                e.insert(TaskEntry::Tombstone);
198            },
199            Occupied(e) if matches!(e.get(), TaskEntry::Tombstone) => {},
200            Occupied(e) => {
201                e.remove();
202            },
203        }
204    }
205}
206
207/// The global [`TaskKillswitch`] exposed publicly from the crate.
208static TASK_KILLSWITCH: LazyLock<TaskKillswitch> =
209    LazyLock::new(TaskKillswitch::with_leaked_storage);
210
211/// Spawns a new asynchronous task and registers it in the crate's global
212/// killswitch.
213///
214/// Under the hood, [`tokio::spawn`] schedules the actual execution.
215#[inline]
216pub fn spawn_with_killswitch(fut: impl Future<Output = ()> + Send + 'static) {
217    TASK_KILLSWITCH.spawn_task(fut);
218}
219
220#[deprecated = "activate() was unnecessarily declared async. Use activate_now() instead."]
221pub async fn activate() {
222    TASK_KILLSWITCH.activate()
223}
224
225/// Triggers the killswitch, thereby scheduling all registered tasks to be
226/// killed.
227///
228/// Note: tasks are not killed synchronously in this function. This means
229/// `activate_now()` will return before all tasks have been stopped.
230#[inline]
231pub fn activate_now() {
232    TASK_KILLSWITCH.activate();
233}
234
235/// Returns a future that resolves when all registered tasks have been killed,
236/// after [`activate_now`] has been called.
237///
238/// Note: tokio does not kill a task until the next time it yields to the
239/// runtime. This means some killed tasks may still be running by the time this
240/// Future resolves.
241#[inline]
242pub fn killed_signal() -> impl Future<Output = ()> + Send + 'static {
243    TASK_KILLSWITCH.killed()
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future;
    use std::time::Duration;
    use tokio::sync::oneshot;

    /// Sends a message on its channel when dropped. Moving it into a task
    /// lets a test observe the moment the task's future is dropped.
    struct TaskAbortSignal(Option<oneshot::Sender<()>>);

    impl TaskAbortSignal {
        /// Creates the signal together with the receiver that observes it.
        fn new() -> (Self, oneshot::Receiver<()>) {
            let (sender, receiver) = oneshot::channel();

            (Self(Some(sender)), receiver)
        }
    }

    impl Drop for TaskAbortSignal {
        fn drop(&mut self) {
            let sender = self.0.take().unwrap();
            // The receiver may already be gone, which is fine.
            let _ = sender.send(());
        }
    }

    /// Spawns 1000 long-sleeping tasks on `killswitch` and returns one
    /// receiver per task, each resolving once that task's future is dropped.
    fn start_test_tasks(
        killswitch: &TaskKillswitch,
    ) -> Vec<oneshot::Receiver<()>> {
        let mut receivers = Vec::new();

        for _ in 0..1000 {
            let (tx, rx) = TaskAbortSignal::new();

            killswitch.spawn_task(async move {
                tokio::time::sleep(Duration::from_secs(3600)).await;
                drop(tx);
            });

            receivers.push(rx);
        }

        receivers
    }

    /// Waits for every abort signal to fire and fails the test if that takes
    /// longer than one second.
    async fn expect_all_killed(abort_signals: Vec<oneshot::Receiver<()>>) {
        let all_signals = future::join_all(abort_signals);

        tokio::time::timeout(Duration::from_secs(1), all_signals)
            .await
            .expect("tasks should be killed within given timeframe");
    }

    #[tokio::test]
    async fn activate_killswitch_early() {
        let killswitch = TaskKillswitch::with_leaked_storage();
        let abort_signals = start_test_tasks(&killswitch);

        // Activate right away, before the tasks had a chance to run.
        killswitch.activate();

        expect_all_killed(abort_signals).await;
    }

    #[tokio::test]
    async fn activate_killswitch_with_delay() {
        let killswitch = TaskKillswitch::with_leaked_storage();
        let abort_signals = start_test_tasks(&killswitch);
        let signal_handle = tokio::spawn(killswitch.killed());

        // NOTE: give tasks time to start executing.
        tokio::time::sleep(Duration::from_millis(200)).await;

        // The signal must not resolve before the killswitch is activated.
        assert!(!signal_handle.is_finished());
        killswitch.activate();

        expect_all_killed(abort_signals).await;

        tokio::time::timeout(Duration::from_secs(1), signal_handle)
            .await
            .expect("killed() signal should have resolved")
            .expect("signal task should join successfully");
    }
}
326}