spawns_core/
spawn.rs

1#[cfg(feature = "compat")]
2use crate::find_spawn;
3use crate::{JoinHandle, Name, Task};
4use std::cell::RefCell;
5use std::future::Future;
6use std::thread_local;
7
8thread_local! {
9    static SPAWNER: RefCell<Option<&'static dyn Spawn>> = RefCell::new(None);
10}
11
/// Trait to spawn task.
///
/// A [Spawn] becomes the current thread's spawner through [enter]; [spawn] then hands every
/// newly created [Task] to [Spawn::spawn].
pub trait Spawn {
    /// Receives ownership of a newly created `task`.
    ///
    /// What happens to the task is up to the implementation. A panic raised here propagates to
    /// the caller of [spawn].
    fn spawn(&self, task: Task);
}
16
17/// Scope where tasks are [spawn]ed through given [Spawn].
18pub struct SpawnScope<'a> {
19    spawner: &'a dyn Spawn,
20    previous: Option<&'static dyn Spawn>,
21}
22
23fn exchange(spawner: Option<&dyn Spawn>) -> Option<&'static dyn Spawn> {
24    SPAWNER.with_borrow_mut(|previous| unsafe {
25        std::mem::replace(
26            previous,
27            std::mem::transmute::<Option<&dyn Spawn>, Option<&'static dyn Spawn>>(spawner),
28        )
29    })
30}
31
32/// Enters a scope where new tasks will be [spawn]ed through given [Spawn].
33pub fn enter(spawner: &dyn Spawn) -> SpawnScope<'_> {
34    let previous = exchange(Some(spawner));
35    SpawnScope { previous, spawner }
36}
37
38impl Drop for SpawnScope<'_> {
39    fn drop(&mut self) {
40        let current = exchange(self.previous.take()).expect("no spawner");
41        assert!(std::ptr::eq(self.spawner, current));
42    }
43}
44
45pub(crate) fn spawn_with_name<T, F>(name: Name, f: F) -> JoinHandle<T>
46where
47    F: Future<Output = T> + Send + 'static,
48    T: Send + 'static,
49{
50    SPAWNER
51        .with_borrow(|spawner| match spawner {
52            Some(spawner) => {
53                let (task, handle) = Task::new(name, f);
54                spawner.spawn(task);
55                Some(handle)
56            }
57            #[cfg(not(feature = "compat"))]
58            None => None,
59            #[cfg(feature = "compat")]
60            None => match find_spawn() {
61                Some(spawn) => {
62                    let (task, handle) = Task::new(name, f);
63                    spawn(task);
64                    Some(handle)
65                }
66                None => None,
67            },
68        })
69        .expect("no spawner")
70}
71
72/// Spawns a new task.
73///
74/// # Panics
75/// 1. Panic if no spawner.
76/// 2. Panic if [Spawn::spawn] panic.
77pub fn spawn<T, F>(f: F) -> JoinHandle<T>
78where
79    F: Future<Output = T> + Send + 'static,
80    T: Send + 'static,
81{
82    spawn_with_name(Name::default(), f)
83}
84
// Tests for scoped spawning (`enter` + `spawn`) and, with feature `compat`, for the fallback
// spawners registered through `COMPATS`.
#[cfg(test)]
mod tests {
    use crate::{enter, id, spawn, Builder, Spawn, Task};
    use futures::executor::block_on;
    use std::future::{pending, ready};

    // Spawner that drops every task without running it.
    #[derive(Default, Clone, Copy)]
    struct DropSpawner {}

    impl Spawn for DropSpawner {
        fn spawn(&self, _task: Task) {}
    }

    // Spawner that runs every task to completion on a dedicated thread.
    #[derive(Default, Clone, Copy)]
    struct ThreadSpawner {}

    impl Spawn for ThreadSpawner {
        fn spawn(&self, task: Task) {
            std::thread::Builder::new()
                // The thread carries the task's name, which the naming tests read back.
                .name(task.name().to_string())
                .spawn(move || {
                    // Enter a scope on the new thread so the task itself can call `spawn`.
                    let spawner = ThreadSpawner::default();
                    let _scope = enter(&spawner);
                    block_on(Box::into_pin(task.future));
                })
                .unwrap();
        }
    }

    // Without a scope (and without compat fallbacks) spawning must panic.
    #[cfg(not(feature = "compat"))]
    #[test]
    #[should_panic(expected = "no spawner")]
    fn no_spawner() {
        spawn(ready(()));
    }

    // A task dropped by its spawner is reported as cancelled to the handle.
    #[test]
    fn drop_spawner() {
        let spawner = DropSpawner::default();
        let _scope = enter(&spawner);
        let handle = spawn(ready(()));
        let err = block_on(handle).unwrap_err();
        assert!(err.is_cancelled());
    }

    // A name given through `Builder` reaches the spawner (observed via the thread name).
    #[test]
    fn thread_spawner_named() {
        let spawner = ThreadSpawner::default();
        let _scope = enter(&spawner);
        let handle = Builder::new()
            .name("task1")
            .spawn(async { std::thread::current().name().unwrap().to_string() });
        let name = block_on(handle).unwrap();
        assert_eq!(name, "task1");
    }

    // A task spawned through plain `spawn` is expected to be named "unnamed".
    #[test]
    fn thread_spawner_unnamed() {
        let spawner = ThreadSpawner::default();
        let _scope = enter(&spawner);
        let handle = spawn(async { std::thread::current().name().unwrap().to_string() });
        let name = block_on(handle).unwrap();
        assert_eq!(name, "unnamed");
    }

    // A task spawns another task from the spawned thread; the inner task's `id()` result must
    // equal the id reported by its handle.
    #[test]
    fn thread_spawner_cascading_ready() {
        let spawner = ThreadSpawner::default();
        let _scope = enter(&spawner);
        // The outer task intentionally yields the inner task's handle instead of awaiting it.
        #[allow(clippy::async_yields_async)]
        let handle = spawn(async move { spawn(async { id() }) });
        let handle = block_on(handle).unwrap();
        let id = handle.id();
        assert_eq!(block_on(handle).unwrap(), id);
    }

    // A never-completing inner task is cancelled through its handle.
    #[test]
    fn thread_spawner_cascading_cancel() {
        let spawner = ThreadSpawner::default();
        let _scope = enter(&spawner);
        // The outer task intentionally yields the inner task's handle instead of awaiting it.
        #[allow(clippy::async_yields_async)]
        let handle = spawn(async move { spawn(pending::<()>()) });
        let handle = block_on(handle).unwrap();
        handle.cancel();
        let err = block_on(handle).unwrap_err();
        assert!(err.is_cancelled());
    }

    // Tests for spawners discovered through `COMPATS` when no scope has been entered.
    #[cfg(feature = "compat")]
    mod compat {
        use super::*;
        use crate::{Compat, COMPATS};
        use linkme::distributed_slice;
        use std::cell::Cell;
        // Per-thread switch: `drop_local` only offers a spawner on threads that stored one here.
        thread_local! {
            static DROP_SPAWNER: Cell<Option<DropSpawner>> = const {  Cell::new(None) };
        }

        #[distributed_slice(COMPATS)]
        pub static DROP_LOCAL: Compat = Compat::Local(drop_local);

        // Spawns through the `DropSpawner` stored for the current thread.
        fn drop_spawn(task: Task) {
            DROP_SPAWNER.get().expect("no drop spawner").spawn(task)
        }

        // Offers `drop_spawn` only when the current thread has a `DropSpawner` stored.
        fn drop_local() -> Option<fn(Task)> {
            DROP_SPAWNER.get().map(|_| drop_spawn as fn(Task))
        }

        // Per-thread switch: `thread_local` only offers a spawner on threads that stored one here.
        thread_local! {
            static THREAD_SPAWNER: Cell<Option<ThreadSpawner>> = const { Cell::new(None) };
        }

        #[distributed_slice(COMPATS)]
        pub static THREAD_LOCAL: Compat = Compat::Local(thread_local);

        // Global spawner registered only with feature `test-compat-global1`.
        // NOTE(review): `allow(deprecated)` suggests `Compat::Global` is deprecated — confirm.
        #[cfg(feature = "test-compat-global1")]
        #[distributed_slice(COMPATS)]
        #[allow(deprecated)]
        pub static THREAD_GLOBAL: Compat = Compat::Global(thread_global);

        // Named global spawner ("drop") registered only with feature `test-compat-global2`.
        #[cfg(feature = "test-compat-global2")]
        #[distributed_slice(COMPATS)]
        pub static DROP_GLOBAL: Compat = Compat::NamedGlobal {
            name: "drop",
            spawn: drop_global,
        };

        #[cfg(feature = "test-compat-global2")]
        fn drop_global(task: Task) {
            DropSpawner::default().spawn(task)
        }

        // Spawns through the `ThreadSpawner` stored for the current thread.
        fn thread_spawn(task: Task) {
            THREAD_SPAWNER.get().expect("no thread spawner").spawn(task)
        }

        // Offers `thread_spawn` only when the current thread has a `ThreadSpawner` stored.
        fn thread_local() -> Option<fn(Task)> {
            THREAD_SPAWNER.get().map(|_| thread_spawn as fn(Task))
        }

        #[cfg(feature = "test-compat-global1")]
        fn thread_global(task: Task) {
            ThreadSpawner::default().spawn(task)
        }

        // With no global spawner compiled in and no thread-local one stored, spawning panics.
        #[test]
        #[cfg(not(any(feature = "test-compat-global1", feature = "test-compat-global2")))]
        #[should_panic(expected = "no spawner")]
        fn no_spawner() {
            spawn(ready(()));
        }

        // The local drop spawner is picked up once stored for this thread.
        #[test]
        fn drop_spawner_local() {
            DROP_SPAWNER.set(Some(DropSpawner::default()));
            let handle = spawn(ready(()));
            let err = block_on(handle).unwrap_err();
            assert!(err.is_cancelled());
        }

        // The local thread spawner is picked up once stored for this thread.
        #[test]
        fn thread_spawner_local() {
            THREAD_SPAWNER.set(Some(ThreadSpawner::default()));
            let handle = Builder::new()
                .name("task2")
                .spawn(async { std::thread::current().name().unwrap().to_string() });
            let name = block_on(handle).unwrap();
            assert_eq!(name, "task2");
        }

        // With only the thread-based global spawner compiled in, it serves tasks without any
        // scope or thread-local setup.
        #[cfg(all(feature = "test-compat-global1", not(feature = "test-compat-global2")))]
        #[test]
        fn thread_spawner_global() {
            let handle = Builder::new()
                .name("thread_spawner_global")
                .spawn(async { std::thread::current().name().unwrap().to_string() });
            let name = block_on(handle).unwrap();
            assert_eq!(name, "thread_spawner_global");
        }

        // NOTE(review): presumably run with both `test-compat-global1` and
        // `test-compat-global2` enabled so that two global spawners are registered — confirm.
        #[cfg(feature = "test-compat-global2")]
        #[cfg(not(feature = "test-named-global"))]
        #[cfg(feature = "panic-multiple-global-spawners")]
        #[test]
        #[should_panic(expected = "multiple global spawners")]
        fn multiple_globals() {
            spawn(ready(()));
        }

        #[cfg(feature = "test-compat-global2")]
        #[cfg(not(feature = "test-named-global"))]
        #[cfg(not(feature = "panic-multiple-global-spawners"))]
        #[test]
        fn multiple_globals() {
            // The one chosen is indeterminate.
            spawn(ready(()));
        }

        // Rust runs all tests for a given feature set in one process, so it is crucial to keep
        // the feature set unique for this test, as it sets the process-wide environment
        // variable SPAWNS_GLOBAL_SPAWNER.
        #[cfg(feature = "test-compat-global2")]
        #[cfg(feature = "test-named-global")]
        #[cfg(feature = "panic-multiple-global-spawners")]
        #[test]
        fn multiple_globals_choose_named() {
            std::env::set_var("SPAWNS_GLOBAL_SPAWNER", "drop");
            let handle = spawn(ready(()));
            let err = block_on(handle).unwrap_err();
            assert!(err.is_cancelled());
        }
    }
}