1
2
3
4#![no_std]
48
49#[cfg(feature = "alloc")]
50extern crate alloc;
51
52use core::task::Poll;
53use core::{future::Future, sync::atomic::AtomicUsize};
54use core::{ops::Deref, sync::atomic::Ordering::SeqCst};
55
56use futures_util::task::AtomicWaker;
57
/// A group of spawned tasks that can be awaited as a whole.
///
/// Futures are started with [`TaskCollection::spawn`]. Awaiting the collection
/// itself resolves once the shared tracker's active-task count reaches zero,
/// i.e. once every tracked task has finished or been dropped by its executor.
///
/// `S` is the [`Spawner`] used to start tasks. `T` is a handle to the shared
/// [`Tracker`] that counts them: `&'static Tracker`, or `Arc<Tracker>` with the
/// `alloc` feature.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct TaskCollection<S, T> {
    /// Executor handle that new tasks are submitted to.
    spawner: S,
    /// Handle to the shared counter/waker pair tracking the live tasks.
    tracker: T,
}
62
63impl<S> TaskCollection<S, ()>
64where
65 S: Spawner,
66{
67 pub fn with_static_tracker(
68 spawner: S,
69 tracker: &'static Tracker,
70 ) -> TaskCollection<S, &'static Tracker> {
71 TaskCollection { spawner, tracker }
72 }
73
74 #[cfg(feature = "alloc")]
75 pub fn new(spawner: S) -> TaskCollection<S, alloc::sync::Arc<Tracker>> {
76 TaskCollection {
77 spawner,
78 tracker: alloc::sync::Arc::new(Tracker::new()),
79 }
80 }
81}
82
83impl<S, T> TaskCollection<S, T>
84where
85 S: Spawner,
86 T: 'static + Deref<Target = Tracker> + Clone + Send,
87{
88 pub fn spawn<F, R>(&self, future: F)
89 where
90 F: Future<Output = R> + Send + 'static,
91 {
92 let tracker = self.create_task();
93 self.spawner.spawn(async {
94 let _ = future.await;
95 core::mem::drop(tracker);
96 });
97 }
98
99 fn create_task(&self) -> Task<T> {
100 let mut current_tasks = self.tracker.active_tasks.load(SeqCst);
101
102 loop {
103 if current_tasks == usize::MAX {
104 panic!();
105 }
106
107 let new_tasks = current_tasks + 1;
108
109 let actual_current =
110 self.tracker
111 .active_tasks
112 .compare_and_swap(current_tasks, new_tasks, SeqCst);
113
114 if current_tasks == actual_current {
115 return Task {
116 inner: self.tracker.clone(),
117 };
118 }
119
120 current_tasks = actual_current;
121 }
122 }
123}
124
125impl<S, T> Future for TaskCollection<S, T>
126where
127 T: core::ops::Deref<Target = Tracker>,
128{
129 type Output = ();
130
131 fn poll(
132 self: core::pin::Pin<&mut Self>,
133 cx: &mut core::task::Context<'_>,
134 ) -> core::task::Poll<Self::Output> {
135 let active_tasks = self.tracker.active_tasks.load(SeqCst);
136
137 if active_tasks == 0 {
138 Poll::Ready(())
139 } else {
140 self.tracker.waker.register(cx.waker());
141
142 let active_tasks = self.tracker.active_tasks.load(SeqCst);
143 if active_tasks == 0 {
144 Poll::Ready(())
145 } else {
146 Poll::Pending
147 }
148 }
149 }
150}
151
152struct Task<T>
153where
154 T: Deref<Target = Tracker>,
155{
156 inner: T,
157}
158
159impl<T> Drop for Task<T>
160where
161 T: Deref<Target = Tracker>,
162{
163 fn drop(&mut self) {
164 let previous = self.inner.active_tasks.fetch_sub(1, SeqCst);
165
166 if previous == 1 {
167 self.inner.waker.wake();
168 }
169 }
170}
171
172pub struct Tracker {
173 waker: AtomicWaker,
174 active_tasks: AtomicUsize,
175}
176
177impl Tracker {
178 pub const fn new() -> Tracker {
179 Tracker {
180 waker: AtomicWaker::new(),
181 active_tasks: AtomicUsize::new(0),
182 }
183 }
184}
185
/// Abstraction over an executor that can start `'static` futures.
///
/// `TaskCollection::spawn` hands each tracked future to this trait and never
/// awaits it itself. The implementations in this crate submit the future to
/// smol, Tokio or async-std without keeping the handle.
pub trait Spawner {
    /// Submits `future` to the executor to be run.
    fn spawn<F: Future<Output = ()> + Send + 'static>(&self, future: F);
}
191
#[cfg(feature = "smol")]
impl Spawner for &smol::Executor<'_> {
    /// Spawns `future` onto this smol executor.
    ///
    /// The returned smol task is detached because dropping a smol task
    /// handle cancels the task.
    fn spawn<F: Future<Output = ()> + Send + 'static>(&self, future: F) {
        let task = smol::Executor::spawn(self, future);
        task.detach();
    }
}
201
#[cfg(feature = "tokio")]
impl Spawner for &tokio::runtime::Runtime {
    /// Spawns `future` onto this Tokio runtime.
    fn spawn<F: Future<Output = ()> + Send + 'static>(&self, future: F) {
        // The JoinHandle is discarded on purpose. Dropping a Tokio JoinHandle
        // detaches the task rather than cancelling it.
        let _ = tokio::runtime::Runtime::spawn(self, future);
    }
}
211
/// Spawner that submits futures to the ambient Tokio runtime via
/// `tokio::spawn`.
///
/// # Panics
///
/// Spawning panics when it is not done from within a Tokio runtime context,
/// because `tokio::spawn` panics in that case.
#[cfg(feature = "tokio")]
#[derive(Debug, Clone, Copy, Default)]
pub struct GlobalTokioSpawner;

#[cfg(feature = "tokio")]
impl Spawner for GlobalTokioSpawner {
    fn spawn<F>(&self, future: F)
    where
        F: core::future::Future<Output = ()> + Send + 'static,
    {
        // The returned JoinHandle is dropped, which detaches the task.
        tokio::spawn(future);
    }
}
224
/// Spawner that submits futures to the global async-std executor via
/// `async_std::task::spawn`.
#[cfg(feature = "async-std")]
#[derive(Debug, Clone, Copy, Default)]
pub struct AsyncStdSpawner;

#[cfg(feature = "async-std")]
impl Spawner for AsyncStdSpawner {
    fn spawn<F>(&self, future: F)
    where
        F: core::future::Future<Output = ()> + Send + 'static,
    {
        // The returned JoinHandle is not kept.
        async_std::task::spawn(future);
    }
}
237
#[cfg(test)]
mod tests {
    extern crate std;

    // Imports live inside each test so that `cargo test` builds cleanly under
    // any feature combination. The executor integrations are feature-gated,
    // and `TaskCollection::new` is only compiled with the `alloc` feature,
    // so tests that call it are gated on `alloc` as well.

    #[test]
    #[cfg(all(feature = "smol", feature = "alloc"))]
    fn test_smol() {
        use crate::TaskCollection;
        use smol::future::FutureExt;
        use std::time::Duration;

        let exec = smol::Executor::new();

        let f = async {
            let collection = TaskCollection::new(&exec);

            for i in &[5, 3, 1, 4, 2] {
                collection.spawn(async move {
                    smol::Timer::after(Duration::from_secs(*i)).await;
                });
            }

            collection.await;
        };

        // Fail the test instead of hanging if the collection never resolves.
        let timeout = async {
            smol::Timer::after(Duration::from_secs(10)).await;
            panic!();
        };

        smol::block_on(exec.run(f.or(timeout)));
    }

    #[test]
    #[cfg(feature = "smol")]
    fn test_smol_static() {
        use crate::{TaskCollection, Tracker};
        use smol::future::FutureExt;
        use std::time::Duration;

        let exec = smol::Executor::new();
        static T: Tracker = Tracker::new();
        let f = async {
            let collection = TaskCollection::with_static_tracker(&exec, &T);

            for i in &[5, 3, 1, 4, 2] {
                collection.spawn(async move {
                    smol::Timer::after(Duration::from_secs(*i)).await;
                });
            }

            collection.await;
        };

        // Fail the test instead of hanging if the collection never resolves.
        let timeout = async {
            smol::Timer::after(Duration::from_secs(10)).await;
            panic!();
        };

        smol::block_on(exec.run(f.or(timeout)));
    }

    #[test]
    #[cfg(all(feature = "tokio", feature = "alloc"))]
    fn test_tokio() {
        use crate::TaskCollection;
        use std::time::Duration;

        let runtime = tokio::runtime::Runtime::new().unwrap();

        let f = async {
            let collection = TaskCollection::new(&runtime);

            for i in &[5, 3, 1, 4, 2] {
                collection.spawn(async move {
                    tokio::time::sleep(Duration::from_secs(*i)).await;
                });
            }

            collection.await;
        };

        runtime.block_on(async {
            tokio::select! {
                _ = f => (),
                _ = tokio::time::sleep(Duration::from_secs(10)) => panic!()
            }
        });
    }

    #[test]
    #[cfg(all(feature = "async-std", feature = "alloc"))]
    fn test_async_std() {
        use crate::{AsyncStdSpawner, TaskCollection};
        use std::time::Duration;

        let f = async {
            let collection = TaskCollection::new(AsyncStdSpawner);

            for i in &[5, 3, 1, 4, 2] {
                collection.spawn(async move {
                    async_std::task::sleep(Duration::from_secs(*i)).await;
                });
            }

            collection.await;
        };

        async_std::task::block_on(async_std::future::timeout(Duration::from_secs(10), f))
            .unwrap();
    }
}