task_killswitch/
lib.rs

// Copyright (C) 2025, Cloudflare, Inc.
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright notice,
//       this list of conditions and the following disclaimer.
//
//     * Redistributions in binary form must reproduce the above copyright
//       notice, this list of conditions and the following disclaimer in the
//       documentation and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use once_cell::sync::Lazy;
use tokio::sync::mpsc;
use tokio::sync::watch;
use tokio::task::JoinHandle;

use std::collections::HashMap;
use std::future::Future;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;

/// Bookkeeping operations sent to the [`ActiveTasks`] collector task.
enum ActiveTaskOp {
    Add { id: u64, handle: JoinHandle<()> },
    Remove { id: u64 },
}

/// Drop guard for task removal. If a task panics, this makes sure
/// it is removed from [`ActiveTasks`] properly.
struct RemoveOnDrop {
    id: u64,
    task_tx_weak: mpsc::WeakUnboundedSender<ActiveTaskOp>,
}

impl Drop for RemoveOnDrop {
    fn drop(&mut self) {
        if let Some(tx) = self.task_tx_weak.upgrade() {
            let _ = tx.send(ActiveTaskOp::Remove { id: self.id });
        }
    }
}

/// A task killswitch that allows aborting all the tasks spawned with it at
/// once. The implementation strives to avoid in-band locking: spawning a
/// future only takes an uncontended read lock, so the regular pace of
/// operation is preserved.
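///
/// A minimal sketch of the intended call sequence (the same pattern the tests
/// at the bottom of this file use); the block is `ignore`d since the type is
/// crate-private:
///
/// ```ignore
/// let killswitch = TaskKillswitch::new();
///
/// killswitch.spawn_task(async { /* background work */ });
///
/// // Abort everything spawned so far, then wait until the aborts are done.
/// let killed = killswitch.killed();
/// killswitch.activate();
/// killed.await;
/// ```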
struct TaskKillswitch {
    // NOTE: use a lock without poisoning here, so that a panicking worker
    // thread doesn't poison the lock and take down all the other threads.
    task_tx: parking_lot::RwLock<Option<mpsc::UnboundedSender<ActiveTaskOp>>>,
    task_counter: AtomicU64,
    all_killed: watch::Receiver<()>,
}

impl TaskKillswitch {
    fn new() -> Self {
        let (task_tx, task_rx) = mpsc::unbounded_channel();
        let (signal_killed, all_killed) = watch::channel(());

        let active_tasks = ActiveTasks {
            task_rx,
            tasks: Default::default(),
            signal_killed,
        };
        tokio::spawn(active_tasks.collect());

        Self {
            task_tx: parking_lot::RwLock::new(Some(task_tx)),
            task_counter: Default::default(),
            all_killed,
        }
    }

    fn spawn_task(&self, fut: impl Future<Output = ()> + Send + 'static) {
        // NOTE: acquiring the lock here is very cheap: unless the killswitch
        // has been activated, the lock is always unlocked, so this is just a
        // few atomic operations.
        let Some(task_tx) = self.task_tx.read().as_ref().cloned() else {
            return;
        };

        let id = self.task_counter.fetch_add(1, Ordering::SeqCst);
        let task_tx_weak = task_tx.downgrade();

        let handle = tokio::spawn(async move {
            // NOTE: we use a weak sender inside the spawned task - dropping
            // all strong senders activates the killswitch. In that case,
            // we don't need to remove anything from ActiveTasks anymore.
            let _guard = RemoveOnDrop { task_tx_weak, id };
            fut.await;
        });

        let _ = task_tx.send(ActiveTaskOp::Add { id, handle });
    }

    fn activate(&self) {
        // take()ing the sender here drops it and thereby triggers the
        // killswitch. Concurrent spawn_task calls may still hold strong
        // senders, which ensures those tasks are added to ActiveTasks before
        // the killing starts.
        assert!(
            self.task_tx.write().take().is_some(),
            "killswitch can't be used twice"
        );
    }

    fn killed(&self) -> impl Future<Output = ()> + Send + 'static {
        let mut signal = self.all_killed.clone();
        async move {
            let _ = signal.changed().await;
        }
    }
}

/// Bookkeeping state owned by the background collector task: it receives
/// [`ActiveTaskOp`]s and keeps the set of currently running task handles.
struct ActiveTasks {
    task_rx: mpsc::UnboundedReceiver<ActiveTaskOp>,
    tasks: HashMap<u64, JoinHandle<()>>,
    signal_killed: watch::Sender<()>,
}

impl ActiveTasks {
    /// Processes bookkeeping ops until all strong senders are dropped (i.e.
    /// the killswitch is activated), then aborts the remaining tasks and
    /// signals completion by dropping the watch sender.
    async fn collect(mut self) {
        while let Some(op) = self.task_rx.recv().await {
            self.handle_task_op(op);
        }

        for task in self.tasks.into_values() {
            task.abort();
        }
        drop(self.signal_killed);
    }

    fn handle_task_op(&mut self, op: ActiveTaskOp) {
        match op {
            ActiveTaskOp::Add { id, handle } => {
                self.tasks.insert(id, handle);
            },
            ActiveTaskOp::Remove { id } => {
                self.tasks.remove(&id);
            },
        }
    }
}

/// The global [`TaskKillswitch`] instance backing the crate's public API.
static TASK_KILLSWITCH: Lazy<TaskKillswitch> = Lazy::new(TaskKillswitch::new);

/// Spawns a new asynchronous task and registers it with the crate's global
/// killswitch.
///
/// Under the hood, [`tokio::spawn`] schedules the actual execution.
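///
/// A minimal usage sketch (marked `no_run` because spawning requires a live
/// tokio runtime):
///
/// ```no_run
/// task_killswitch::spawn_with_killswitch(async {
///     // Long-running background work that should stop once the killswitch
///     // is activated.
/// });
/// ```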
#[inline]
pub fn spawn_with_killswitch(fut: impl Future<Output = ()> + Send + 'static) {
    TASK_KILLSWITCH.spawn_task(fut);
}

#[deprecated = "activate() was unnecessarily declared async. Use activate_now() instead."]
pub async fn activate() {
    TASK_KILLSWITCH.activate()
}

/// Triggers the killswitch, thereby scheduling all registered tasks to be
/// killed.
///
/// Note: tasks are not killed synchronously in this function. This means
/// `activate_now()` will return before all tasks have been stopped. Use
/// [`killed_signal`] to wait for the aborts to complete.
#[inline]
pub fn activate_now() {
    TASK_KILLSWITCH.activate();
}

/// Returns a future that resolves when all registered tasks have been killed,
/// after [`activate_now`] has been called.
///
/// Note: tokio does not kill a task until the next time it yields to the
/// runtime. This means some killed tasks may still be running by the time this
/// future resolves.
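///
/// A minimal shutdown sketch combining this with [`activate_now`] (`no_run`;
/// the enclosing async context is hidden):
///
/// ```no_run
/// # async fn shutdown() {
/// // Grab the completion signal, trigger the killswitch, then wait for all
/// // registered tasks to be aborted.
/// let killed = task_killswitch::killed_signal();
/// task_killswitch::activate_now();
/// killed.await;
/// # }
/// ```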
#[inline]
pub fn killed_signal() -> impl Future<Output = ()> + Send + 'static {
    TASK_KILLSWITCH.killed()
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future;
    use std::time::Duration;
    use tokio::sync::oneshot;

    struct TaskAbortSignal(Option<oneshot::Sender<()>>);

    impl TaskAbortSignal {
        fn new() -> (Self, oneshot::Receiver<()>) {
            let (tx, rx) = oneshot::channel();

            (Self(Some(tx)), rx)
        }
    }

    impl Drop for TaskAbortSignal {
        fn drop(&mut self) {
            let _ = self.0.take().unwrap().send(());
        }
    }

    fn start_test_tasks(
        killswitch: &TaskKillswitch,
    ) -> Vec<oneshot::Receiver<()>> {
        (0..1000)
            .map(|_| {
                let (tx, rx) = TaskAbortSignal::new();

                killswitch.spawn_task(async move {
                    tokio::time::sleep(tokio::time::Duration::from_secs(3600))
                        .await;
                    drop(tx);
                });

                rx
            })
            .collect()
    }

    #[tokio::test]
    async fn activate_killswitch_early() {
        let killswitch = TaskKillswitch::new();
        let abort_signals = start_test_tasks(&killswitch);

        killswitch.activate();

        tokio::time::timeout(
            Duration::from_secs(1),
            future::join_all(abort_signals),
        )
        .await
        .expect("tasks should be killed within given timeframe");
    }

    #[tokio::test]
    async fn activate_killswitch_with_delay() {
        let killswitch = TaskKillswitch::new();
        let abort_signals = start_test_tasks(&killswitch);
        let signal_handle = tokio::spawn(killswitch.killed());

        // NOTE: give tasks time to start executing.
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        assert!(!signal_handle.is_finished());
        killswitch.activate();

        tokio::time::timeout(
            Duration::from_secs(1),
            future::join_all(abort_signals),
        )
        .await
        .expect("tasks should be killed within given timeframe");

        tokio::time::timeout(Duration::from_secs(1), signal_handle)
            .await
            .expect("killed() signal should have resolved")
            .expect("signal task should join successfully");
    }
}