1#[cfg(feature = "compat")]
2use crate::find_spawn;
3use crate::{JoinHandle, Name, Task};
4use std::cell::RefCell;
5use std::future::Future;
6use std::thread_local;
7
8thread_local! {
9 static SPAWNER: RefCell<Option<&'static dyn Spawn>> = RefCell::new(None);
10}
11
/// A sink for tasks created by [`spawn()`] and the task [`Builder`](crate::Builder).
///
/// An implementation is installed for the current thread with [`enter`]. While the returned
/// [`SpawnScope`] is alive, every task spawned on that thread is handed to it.
pub trait Spawn {
    /// Takes ownership of `task` and decides how (or whether) to run it.
    ///
    /// An implementation may also simply drop the task. The `drop_spawner` test shows that
    /// the matching handle then resolves to an error for which `is_cancelled()` is true.
    fn spawn(&self, task: Task);
}
16
17pub struct SpawnScope<'a> {
19 spawner: &'a dyn Spawn,
20 previous: Option<&'static dyn Spawn>,
21}
22
23fn exchange(spawner: Option<&dyn Spawn>) -> Option<&'static dyn Spawn> {
24 SPAWNER.with_borrow_mut(|previous| unsafe {
25 std::mem::replace(
26 previous,
27 std::mem::transmute::<Option<&dyn Spawn>, Option<&'static dyn Spawn>>(spawner),
28 )
29 })
30}
31
32pub fn enter(spawner: &dyn Spawn) -> SpawnScope<'_> {
34 let previous = exchange(Some(spawner));
35 SpawnScope { previous, spawner }
36}
37
38impl Drop for SpawnScope<'_> {
39 fn drop(&mut self) {
40 let current = exchange(self.previous.take()).expect("no spawner");
41 assert!(std::ptr::eq(self.spawner, current));
42 }
43}
44
45pub(crate) fn spawn_with_name<T, F>(name: Name, f: F) -> JoinHandle<T>
46where
47 F: Future<Output = T> + Send + 'static,
48 T: Send + 'static,
49{
50 SPAWNER
51 .with_borrow(|spawner| match spawner {
52 Some(spawner) => {
53 let (task, handle) = Task::new(name, f);
54 spawner.spawn(task);
55 Some(handle)
56 }
57 #[cfg(not(feature = "compat"))]
58 None => None,
59 #[cfg(feature = "compat")]
60 None => match find_spawn() {
61 Some(spawn) => {
62 let (task, handle) = Task::new(name, f);
63 spawn(task);
64 Some(handle)
65 }
66 None => None,
67 },
68 })
69 .expect("no spawner")
70}
71
72pub fn spawn<T, F>(f: F) -> JoinHandle<T>
78where
79 F: Future<Output = T> + Send + 'static,
80 T: Send + 'static,
81{
82 spawn_with_name(Name::default(), f)
83}
84
// Tests for the scoped spawner API and, with the `compat` feature, for the fallback
// spawners registered in `COMPATS`.
#[cfg(test)]
mod tests {
    use crate::{enter, id, spawn, Builder, Spawn, Task};
    use futures::executor::block_on;
    use std::future::{pending, ready};

    // Spawner that discards every task without ever polling it.
    #[derive(Default, Clone, Copy)]
    struct DropSpawner {}

    impl Spawn for DropSpawner {
        fn spawn(&self, _task: Task) {}
    }

    // Spawner that runs each task to completion on a new OS thread named after the task.
    // It enters a fresh `ThreadSpawner` scope on that thread, so tasks spawned from inside a
    // task also get their own threads.
    #[derive(Default, Clone, Copy)]
    struct ThreadSpawner {}

    impl Spawn for ThreadSpawner {
        fn spawn(&self, task: Task) {
            std::thread::Builder::new()
                .name(task.name().to_string())
                .spawn(move || {
                    let spawner = ThreadSpawner::default();
                    let _scope = enter(&spawner);
                    block_on(Box::into_pin(task.future));
                })
                .unwrap();
        }
    }

    // With no scope entered (and no compat fallbacks compiled in), `spawn` must panic.
    #[cfg(not(feature = "compat"))]
    #[test]
    #[should_panic(expected = "no spawner")]
    fn no_spawner() {
        spawn(ready(()));
    }

    // A task dropped by its spawner resolves the handle with a cancellation error.
    #[test]
    fn drop_spawner() {
        let spawner = DropSpawner::default();
        let _scope = enter(&spawner);
        let handle = spawn(ready(()));
        let err = block_on(handle).unwrap_err();
        assert!(err.is_cancelled());
    }

    // A name set through `Builder` reaches the spawner; it is observed here as the thread name.
    #[test]
    fn thread_spawner_named() {
        let spawner = ThreadSpawner::default();
        let _scope = enter(&spawner);
        let handle = Builder::new()
            .name("task1")
            .spawn(async { std::thread::current().name().unwrap().to_string() });
        let name = block_on(handle).unwrap();
        assert_eq!(name, "task1");
    }

    // Tasks spawned without a name are reported as "unnamed".
    #[test]
    fn thread_spawner_unnamed() {
        let spawner = ThreadSpawner::default();
        let _scope = enter(&spawner);
        let handle = spawn(async { std::thread::current().name().unwrap().to_string() });
        let name = block_on(handle).unwrap();
        assert_eq!(name, "unnamed");
    }

    // A task spawned from inside another task (on the spawner's thread) reports, via `id()`,
    // the same id that its handle exposes.
    #[test]
    fn thread_spawner_cascading_ready() {
        let spawner = ThreadSpawner::default();
        let _scope = enter(&spawner);
        // The outer task deliberately yields the inner task's `JoinHandle` (itself a future).
        #[allow(clippy::async_yields_async)]
        let handle = spawn(async move { spawn(async { id() }) });
        let handle = block_on(handle).unwrap();
        let id = handle.id();
        assert_eq!(block_on(handle).unwrap(), id);
    }

    // Cancelling the handle of a never-completing nested task makes awaiting it report
    // cancellation.
    #[test]
    fn thread_spawner_cascading_cancel() {
        let spawner = ThreadSpawner::default();
        let _scope = enter(&spawner);
        // The outer task deliberately yields the inner task's `JoinHandle` (itself a future).
        #[allow(clippy::async_yields_async)]
        let handle = spawn(async move { spawn(pending::<()>()) });
        let handle = block_on(handle).unwrap();
        handle.cancel();
        let err = block_on(handle).unwrap_err();
        assert!(err.is_cancelled());
    }

    // Fallback spawners registered into `COMPATS`, used when no scope is entered.
    // NOTE(review): the thread-local fixtures below are set and never reset. This relies on
    // each test running on its own thread (libtest's default); otherwise state set by one
    // test could leak into later ones.
    #[cfg(feature = "compat")]
    mod compat {
        use super::*;
        use crate::{Compat, COMPATS};
        use linkme::distributed_slice;
        use std::cell::Cell;
        thread_local! {
            static DROP_SPAWNER: Cell<Option<DropSpawner>> = const { Cell::new(None) };
        }

        // Thread-local fallback that is active only once a test stores a `DropSpawner`.
        #[distributed_slice(COMPATS)]
        pub static DROP_LOCAL: Compat = Compat::Local(drop_local);

        fn drop_spawn(task: Task) {
            DROP_SPAWNER.get().expect("no drop spawner").spawn(task)
        }

        // Offers `drop_spawn` only while `DROP_SPAWNER` is set on the calling thread.
        fn drop_local() -> Option<fn(Task)> {
            DROP_SPAWNER.get().map(|_| drop_spawn as fn(Task))
        }

        thread_local! {
            static THREAD_SPAWNER: Cell<Option<ThreadSpawner>> = const { Cell::new(None) };
        }

        // Thread-local fallback that is active only once a test stores a `ThreadSpawner`.
        #[distributed_slice(COMPATS)]
        pub static THREAD_LOCAL: Compat = Compat::Local(thread_local);

        // Unnamed global fallback. NOTE(review): `Compat::Global` appears to be deprecated,
        // judging by the `allow(deprecated)` below.
        #[cfg(feature = "test-compat-global1")]
        #[distributed_slice(COMPATS)]
        #[allow(deprecated)]
        pub static THREAD_GLOBAL: Compat = Compat::Global(thread_global);

        // Named global fallback, selectable as "drop" (see `multiple_globals_choose_named`).
        #[cfg(feature = "test-compat-global2")]
        #[distributed_slice(COMPATS)]
        pub static DROP_GLOBAL: Compat = Compat::NamedGlobal {
            name: "drop",
            spawn: drop_global,
        };

        #[cfg(feature = "test-compat-global2")]
        fn drop_global(task: Task) {
            DropSpawner::default().spawn(task)
        }

        fn thread_spawn(task: Task) {
            THREAD_SPAWNER.get().expect("no thread spawner").spawn(task)
        }

        // Offers `thread_spawn` only while `THREAD_SPAWNER` is set on the calling thread.
        fn thread_local() -> Option<fn(Task)> {
            THREAD_SPAWNER.get().map(|_| thread_spawn as fn(Task))
        }

        #[cfg(feature = "test-compat-global1")]
        fn thread_global(task: Task) {
            ThreadSpawner::default().spawn(task)
        }

        // With no global fallback compiled in and no thread-local set, `spawn` still panics.
        #[test]
        #[cfg(not(any(feature = "test-compat-global1", feature = "test-compat-global2")))]
        #[should_panic(expected = "no spawner")]
        fn no_spawner() {
            spawn(ready(()));
        }

        // A thread-local compat spawner is used when no scope is entered.
        #[test]
        fn drop_spawner_local() {
            DROP_SPAWNER.set(Some(DropSpawner::default()));
            let handle = spawn(ready(()));
            let err = block_on(handle).unwrap_err();
            assert!(err.is_cancelled());
        }

        #[test]
        fn thread_spawner_local() {
            THREAD_SPAWNER.set(Some(ThreadSpawner::default()));
            let handle = Builder::new()
                .name("task2")
                .spawn(async { std::thread::current().name().unwrap().to_string() });
            let name = block_on(handle).unwrap();
            assert_eq!(name, "task2");
        }

        // With exactly one global fallback, it is used when nothing else is available.
        #[cfg(all(feature = "test-compat-global1", not(feature = "test-compat-global2")))]
        #[test]
        fn thread_spawner_global() {
            let handle = Builder::new()
                .name("thread_spawner_global")
                .spawn(async { std::thread::current().name().unwrap().to_string() });
            let name = block_on(handle).unwrap();
            assert_eq!(name, "thread_spawner_global");
        }

        // Several global fallbacks and no explicit choice: panics when
        // `panic-multiple-global-spawners` is enabled.
        #[cfg(feature = "test-compat-global2")]
        #[cfg(not(feature = "test-named-global"))]
        #[cfg(feature = "panic-multiple-global-spawners")]
        #[test]
        #[should_panic(expected = "multiple global spawners")]
        fn multiple_globals() {
            spawn(ready(()));
        }

        // Same setup without the panic feature: spawning must succeed. Which global is
        // picked is not asserted.
        #[cfg(feature = "test-compat-global2")]
        #[cfg(not(feature = "test-named-global"))]
        #[cfg(not(feature = "panic-multiple-global-spawners"))]
        #[test]
        fn multiple_globals() {
            spawn(ready(()));
        }

        // Naming a global spawner through the environment avoids the multiple-globals panic.
        // The variable is presumably read by `find_spawn`.
        // NOTE(review): `std::env::set_var` races with environment reads on other test
        // threads (and is `unsafe` in edition 2024) — consider running this test in isolation.
        #[cfg(feature = "test-compat-global2")]
        #[cfg(feature = "test-named-global")]
        #[cfg(feature = "panic-multiple-global-spawners")]
        #[test]
        fn multiple_globals_choose_named() {
            std::env::set_var("SPAWNS_GLOBAL_SPAWNER", "drop");
            let handle = spawn(ready(()));
            let err = block_on(handle).unwrap_err();
            assert!(err.is_cancelled());
        }
    }
}