1#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
38
39use std::sync::Arc;
40use std::sync::atomic::{AtomicBool, AtomicUsize};
41
42use tokio_util::sync::CancellationToken;
43
44mod ext;
46
47pub use ext::*;
48
49#[derive(Debug)]
51struct ContextTracker(Arc<ContextTrackerInner>);
52
53impl Drop for ContextTracker {
54 fn drop(&mut self) {
55 let prev_active_count = self.0.active_count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
56 if prev_active_count == 1 && self.0.stopped.load(std::sync::atomic::Ordering::Relaxed) {
59 self.0.notify.notify_waiters();
60 }
61 }
62}
63
64#[derive(Debug)]
65struct ContextTrackerInner {
66 stopped: AtomicBool,
67 active_count: AtomicUsize,
70 notify: tokio::sync::Notify,
71}
72
73impl ContextTrackerInner {
74 fn new() -> Arc<Self> {
75 Arc::new(Self {
76 stopped: AtomicBool::new(false),
77 active_count: AtomicUsize::new(0),
78 notify: tokio::sync::Notify::new(),
79 })
80 }
81
82 fn child(self: &Arc<Self>) -> ContextTracker {
84 self.active_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
85 ContextTracker(Arc::clone(self))
86 }
87
88 fn stop(&self) {
90 self.stopped.store(true, std::sync::atomic::Ordering::Relaxed);
91 }
92
93 async fn wait(&self) {
96 let notify = self.notify.notified();
97
98 if self.active_count.load(std::sync::atomic::Ordering::Relaxed) == 0 {
100 return;
101 }
102
103 notify.await;
104 }
105}
106
107#[derive(Debug)]
119pub struct Context {
120 token: CancellationToken,
121 tracker: ContextTracker,
122}
123
124impl Clone for Context {
125 fn clone(&self) -> Self {
126 Self {
127 token: self.token.clone(),
128 tracker: self.tracker.0.child(),
129 }
130 }
131}
132
133impl Context {
134 #[must_use]
135 pub fn new() -> (Self, Handler) {
138 Handler::global().new_child()
139 }
140
141 #[must_use]
142 pub fn new_child(&self) -> (Self, Handler) {
154 let token = self.token.child_token();
155 let tracker = ContextTrackerInner::new();
156
157 (
158 Self {
159 tracker: tracker.child(),
160 token: token.clone(),
161 },
162 Handler {
163 token: Arc::new(TokenDropGuard(token)),
164 tracker,
165 },
166 )
167 }
168
169 #[must_use]
170 pub fn global() -> Self {
172 Handler::global().context()
173 }
174
175 pub async fn done(&self) {
177 self.token.cancelled().await;
178 }
179
180 pub async fn into_done(self) {
182 self.done().await;
183 }
184
185 #[must_use]
187 pub fn is_done(&self) -> bool {
188 self.token.is_cancelled()
189 }
190}
191
192#[derive(Debug)]
195struct TokenDropGuard(CancellationToken);
196
197impl TokenDropGuard {
198 #[must_use]
199 fn child(&self) -> CancellationToken {
200 self.0.child_token()
201 }
202
203 fn cancel(&self) {
204 self.0.cancel();
205 }
206}
207
208impl Drop for TokenDropGuard {
209 fn drop(&mut self) {
210 self.cancel();
211 }
212}
213
214#[derive(Debug, Clone)]
215pub struct Handler {
216 token: Arc<TokenDropGuard>,
217 tracker: Arc<ContextTrackerInner>,
218}
219
220impl Default for Handler {
221 fn default() -> Self {
222 Self::new()
223 }
224}
225
226impl Handler {
227 #[must_use]
228 pub fn new() -> Handler {
230 let token = CancellationToken::new();
231 let tracker = ContextTrackerInner::new();
232
233 Handler {
234 token: Arc::new(TokenDropGuard(token)),
235 tracker,
236 }
237 }
238
239 #[must_use]
240 pub fn global() -> &'static Self {
242 static GLOBAL: std::sync::OnceLock<Handler> = std::sync::OnceLock::new();
243
244 GLOBAL.get_or_init(Handler::new)
245 }
246
247 pub async fn shutdown(&self) {
249 self.cancel();
250 self.done().await;
251 }
252
253 pub async fn done(&self) {
255 self.token.0.cancelled().await;
256 self.wait().await;
257 }
258
259 pub async fn wait(&self) {
263 self.tracker.wait().await;
264 }
265
266 #[must_use]
267 pub fn context(&self) -> Context {
269 Context {
270 token: self.token.child(),
271 tracker: self.tracker.child(),
272 }
273 }
274
275 #[must_use]
276 pub fn new_child(&self) -> (Context, Handler) {
278 self.context().new_child()
279 }
280
281 pub fn cancel(&self) {
283 self.tracker.stop();
284 self.token.cancel();
285 }
286
287 pub fn is_done(&self) -> bool {
289 self.token.0.is_cancelled()
290 }
291}
292
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
#[cfg(test)]
mod tests {
    use scuffle_future_ext::FutureExt;

    use crate::{Context, Handler};

    /// Upper bound used to tell "completes" apart from "hangs".
    const TIMEOUT: std::time::Duration = std::time::Duration::from_millis(200);

    #[tokio::test]
    async fn new() {
        let (ctx, handler) = Context::new();
        assert!(!handler.is_done());
        assert!(!ctx.is_done());

        let fresh = Handler::default();
        assert!(!fresh.is_done());
    }

    #[tokio::test]
    async fn cancel() {
        let (ctx, handler) = Context::new();
        let (child_ctx, child_handler) = ctx.new_child();
        let child_ctx2 = ctx.clone();

        let states = || {
            [
                handler.is_done(),
                ctx.is_done(),
                child_handler.is_done(),
                child_ctx.is_done(),
                child_ctx2.is_done(),
            ]
        };

        assert_eq!(states(), [false; 5]);

        // Cancelling the root propagates to every derived token.
        handler.cancel();

        assert_eq!(states(), [true; 5]);
    }

    #[tokio::test]
    async fn cancel_child() {
        let (ctx, handler) = Context::new();
        let (child_ctx, child_handler) = ctx.new_child();

        let states = || {
            [
                handler.is_done(),
                ctx.is_done(),
                child_handler.is_done(),
                child_ctx.is_done(),
            ]
        };

        assert_eq!(states(), [false, false, false, false]);

        // Cancelling a child must not flow upwards to the parent.
        child_handler.cancel();

        assert_eq!(states(), [false, false, true, true]);
    }

    #[tokio::test]
    async fn shutdown() {
        let (ctx, handler) = Context::new();

        assert!(!handler.is_done());
        assert!(!ctx.is_done());

        // `ctx` is still alive, so shutting down cannot complete yet.
        let blocked = handler.shutdown().with_timeout(TIMEOUT).await;
        assert!(blocked.is_err());
        assert!(handler.is_done());
        assert!(ctx.is_done());

        // Consuming the context drops it once it has observed cancellation.
        let consumed = ctx.into_done().with_timeout(TIMEOUT).await;
        assert!(consumed.is_ok());

        // With no contexts left, every way of waiting resolves promptly.
        let again = handler.shutdown().with_timeout(TIMEOUT).await;
        assert!(again.is_ok());
        let waited = handler.wait().with_timeout(TIMEOUT).await;
        assert!(waited.is_ok());
        let finished = handler.done().with_timeout(TIMEOUT).await;
        assert!(finished.is_ok());

        assert!(handler.is_done());
    }

    // NOTE(review): this cancels the process-wide handler, which the other
    // tests reach through `Context::new()`. If tests share one process and run
    // concurrently, their initial `!is_done()` checks may race with this —
    // confirm the test runner isolates tests per process.
    #[tokio::test]
    async fn global_handler() {
        let global = Handler::global();
        assert!(!global.is_done());

        global.cancel();

        assert!(global.is_done());
        assert!(Handler::global().is_done());
        assert!(Context::global().is_done());

        // Anything derived after cancellation starts out cancelled.
        let (child_ctx, child_handler) = Handler::global().new_child();
        assert!(child_handler.is_done());
        assert!(child_ctx.is_done());
    }
}