1use core::{
2 future::Future,
3 ops::Deref,
4 pin::Pin,
5 sync::atomic::Ordering,
6 task::{Poll, Waker},
7};
8
9#[cfg(feature = "alloc")]
10use alloc::sync::Arc;
11
12use pin_project_lite::pin_project;
13
14use crate::{intrusive::Node, spawn::Spawn, task::Task, State};
15
16pub struct TaskGroup<S: Spawn, C: Deref<Target = State>> {
18 spawner: S,
19 state: C,
20}
21
#[cfg(feature = "alloc")]
impl<S: Spawn> TaskGroup<S, Arc<State>> {
    /// Creates a task group whose shared [`State`] is freshly allocated behind an [`Arc`].
    ///
    /// Only available when the `alloc` feature is enabled.
    pub fn new(spawner: S) -> Self {
        let state = Arc::new(State::new());
        Self { spawner, state }
    }
}
32
33impl<S: Spawn> TaskGroup<S, &'static State> {
34 pub fn with_static(spawner: S, state: &'static State) -> Self {
36 TaskGroup { spawner, state }
37 }
38}
39
40impl<S: Spawn, C: 'static + Deref<Target = State> + Clone + Send> TaskGroup<S, C> {
41 pub async fn shutdown(&self) {
43 critical_section::with(|cs| {
44 self.state.shutdown_signaled.store(true, Ordering::SeqCst);
45
46 let list = self.state.shutdown_wakers.borrow(cs).borrow_mut();
47
48 let mut node = list.peek_front();
49 while let Some(inner_node) = node {
50 if let Some(ref waker) = inner_node.data {
51 waker.clone().wake();
52 }
53
54 node = inner_node.next();
55 }
56 });
57
58 self.done().await;
59 }
60
61 pub async fn done(&self) {
63 DoneFuture {
64 state: self.state.clone(),
65 }
66 .await
67 }
68
69 pub fn spawn(&self, future: impl Future<Output = ()> + Send + 'static) {
71 let task = Task::new(self.state.clone());
72 self.spawner.spawn(async {
73 future.await;
74 core::mem::drop(task);
75 });
76 }
77
78 pub fn spawn_with_shutdown<F>(&self, f: impl FnOnce(ShutdownSignal<C>) -> F)
79 where
80 F: Future<Output = ()> + Send + 'static,
81 {
82 let signal = ShutdownSignal {
83 state: self.state.clone(),
84 node: Node::new(None),
85 };
86 let future = f(signal);
87 self.spawn(future);
88 }
89}
90
91struct DoneFuture<C> {
92 state: C,
93}
94
95impl<C: Deref<Target = State>> Future for DoneFuture<C> {
96 type Output = ();
97
98 fn poll(
99 self: core::pin::Pin<&mut Self>,
100 cx: &mut core::task::Context<'_>,
101 ) -> core::task::Poll<Self::Output> {
102 self.state.done_waker.register(cx.waker());
103
104 if self.state.running_tasks.load(Ordering::SeqCst) == 0 {
105 Poll::Ready(())
106 } else {
107 Poll::Pending
108 }
109 }
110}
111
112pin_project! {
113 pub struct ShutdownSignal<C: Deref<Target = State>> {
115 state: C,
116 #[pin]
117 node: Node<Option<Waker>>,
118 }
119
120 impl<C: Deref<Target = State>> PinnedDrop for ShutdownSignal<C> {
121 fn drop(this: Pin<&mut Self>) {
122 let this = this.project();
123
124 critical_section::with(|cs| {
125 let mut list = this.state.shutdown_wakers.borrow(cs).borrow_mut();
126 if this.node.is_init() {
127 unsafe {this.node.remove(&mut list) };
128 }
129 });
130 }
131 }
132}
133
134impl<C: Deref<Target = State>> Future for ShutdownSignal<C> {
135 type Output = ();
136
137 fn poll(
138 self: core::pin::Pin<&mut Self>,
139 cx: &mut core::task::Context<'_>,
140 ) -> Poll<Self::Output> {
141 let mut this = self.project();
142 unsafe {
143 critical_section::with(|cs| {
144 if this.state.shutdown_signaled.load(Ordering::SeqCst) {
145 return Poll::Ready(());
146 }
147 let node = Pin::as_mut(&mut this.node).get_unchecked_mut();
148 node.data = Some(cx.waker().clone());
149 if !node.is_init() {
150 this.state
151 .shutdown_wakers
152 .borrow(cs)
153 .borrow_mut()
154 .push_front(this.node);
155 }
156 return Poll::Pending;
157 })
158 }
159 }
160}