Skip to main content

wsio_core/traits/task/spawner.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4use tokio::{
5    select,
6    spawn,
7};
8use tokio_util::sync::CancellationToken;
9
10pub trait TaskSpawner: Send + Sync + 'static {
11    fn cancel_token(&self) -> Arc<CancellationToken>;
12
13    #[inline]
14    fn spawn_task<F: Future<Output = Result<()>> + Send + 'static>(&self, future: F) {
15        let cancel_token = self.cancel_token();
16        spawn(async move {
17            select! {
18                _ = cancel_token.cancelled() => {},
19                _ = future => {},
20            }
21        });
22    }
23}
24
#[cfg(test)]
mod tests {
    use std::sync::atomic::{
        AtomicBool,
        Ordering,
    };

    use tokio::{
        sync::oneshot::channel,
        task::yield_now,
    };

    use super::*;

    /// Upper bound on scheduler yields while waiting for a spawned task.
    ///
    /// A regression makes the test fail instead of hanging it. The test also no longer
    /// depends on a single `yield_now` being enough for the other task to run.
    const MAX_YIELDS: usize = 1_000;

    /// Minimal [`TaskSpawner`] backed by a caller-supplied token.
    struct TestSpawner {
        cancel_token: Arc<CancellationToken>,
    }

    impl TaskSpawner for TestSpawner {
        fn cancel_token(&self) -> Arc<CancellationToken> {
            self.cancel_token.clone()
        }
    }

    /// Yields to the runtime until `condition` holds or `MAX_YIELDS` yields have elapsed.
    ///
    /// Returns whether `condition` eventually held.
    async fn yield_until(mut condition: impl FnMut() -> bool) -> bool {
        for _ in 0..MAX_YIELDS {
            if condition() {
                return true;
            }
            yield_now().await;
        }
        condition()
    }

    #[tokio::test]
    async fn test_spawn_task_runs_to_completion() {
        let spawner = TestSpawner {
            cancel_token: Arc::new(CancellationToken::new()),
        };

        let flag = Arc::new(AtomicBool::new(false));
        let flag_clone = flag.clone();

        let (tx, rx) = channel::<()>();

        spawner.spawn_task(async move {
            let _ = rx.await;
            flag_clone.store(true, Ordering::SeqCst);
            Ok(())
        });

        // Trigger the task to complete. The receiver lives inside the spawned future, so a
        // failed send means the future was dropped before it could finish.
        tx.send(())
            .expect("spawned task dropped its receiver before being triggered");

        assert!(
            yield_until(|| flag.load(Ordering::SeqCst)).await,
            "Task should have completed"
        );
    }

    #[tokio::test]
    async fn test_spawn_task_is_cancelled() {
        let cancel_token = Arc::new(CancellationToken::new());
        let spawner = TestSpawner {
            cancel_token: cancel_token.clone(),
        };

        let flag = Arc::new(AtomicBool::new(false));
        let flag_clone = flag.clone();

        let started = Arc::new(AtomicBool::new(false));
        let started_clone = started.clone();

        // The future owns one extra strong reference. The count only returns to 1 once the
        // spawned future itself has been dropped, which is what cancellation must achieve.
        let liveness = Arc::new(());
        let liveness_clone = liveness.clone();

        spawner.spawn_task(async move {
            let _liveness = liveness_clone;
            started_clone.store(true, Ordering::SeqCst);
            // Never resolves: the only way out is cancellation dropping this future.
            std::future::pending::<()>().await;
            flag_clone.store(true, Ordering::SeqCst);
            Ok(())
        });

        // Ensure the future is actually in flight before cancelling. This exercises
        // cancellation of a running task, not just one that was never polled.
        assert!(
            yield_until(|| started.load(Ordering::SeqCst)).await,
            "Task should have started"
        );

        cancel_token.cancel();

        assert!(
            yield_until(|| Arc::strong_count(&liveness) == 1).await,
            "Task future should have been dropped after cancellation"
        );
        assert!(
            !flag.load(Ordering::SeqCst),
            "Task should have been cancelled before completion"
        );
    }
}