widgetkit_runtime/
tasks.rs1use crate::internal::Dispatcher;
2use futures::Future;
3#[cfg(not(feature = "runtime-tokio"))]
4use futures::future::{AbortHandle, Abortable};
5use std::{collections::HashMap, pin::Pin};
6#[cfg(not(feature = "runtime-tokio"))]
7use std::thread;
8use widgetkit_core::TaskId;
9
10pub struct Tasks<'a, M> {
11 backend: &'a mut dyn TaskBackend<M>,
12}
13
14impl<'a, M> Tasks<'a, M>
15where
16 M: Send + 'static,
17{
18 pub(crate) fn new(backend: &'a mut dyn TaskBackend<M>) -> Self {
19 Self { backend }
20 }
21
22 pub fn spawn<F>(&mut self, future: F) -> TaskId
23 where
24 F: Future<Output = M> + Send + 'static,
25 {
26 self.backend.spawn_boxed(None, Box::pin(future))
27 }
28
29 pub fn spawn_named<F>(&mut self, name: impl Into<String>, future: F) -> TaskId
30 where
31 F: Future<Output = M> + Send + 'static,
32 {
33 self.backend.spawn_boxed(Some(name.into()), Box::pin(future))
34 }
35
36 pub fn cancel(&mut self, task_id: TaskId) -> bool {
37 self.backend.cancel(task_id)
38 }
39
40 pub fn cancel_all(&mut self) {
41 self.backend.cancel_all();
42 }
43}
44
45pub(crate) type BoxedFuture<M> = Pin<Box<dyn Future<Output = M> + Send + 'static>>;
46
47pub(crate) trait TaskBackend<M>: Send {
48 fn spawn_boxed(&mut self, name: Option<String>, future: BoxedFuture<M>) -> TaskId;
49 fn cancel(&mut self, task_id: TaskId) -> bool;
50 fn cancel_all(&mut self);
51 fn reap(&mut self, task_id: TaskId);
52 fn shutdown(&mut self);
53 #[cfg(test)]
54 fn active_count(&self) -> usize;
55}
56
57pub(crate) fn task_backend<M>(dispatcher: Dispatcher<M>) -> Box<dyn TaskBackend<M>>
58where
59 M: Send + 'static,
60{
61 #[cfg(feature = "runtime-tokio")]
62 {
63 return Box::new(TokioTaskBackend::new(dispatcher));
64 }
65
66 #[cfg(not(feature = "runtime-tokio"))]
67 {
68 Box::new(DefaultTaskBackend::new(dispatcher))
69 }
70}
71
72#[cfg(not(feature = "runtime-tokio"))]
73struct DefaultTaskBackend<M> {
74 dispatcher: Dispatcher<M>,
75 tasks: HashMap<TaskId, DefaultTaskControl>,
76 shutting_down: bool,
77}
78
79#[cfg(not(feature = "runtime-tokio"))]
80struct DefaultTaskControl {
81 #[allow(dead_code)]
82 name: Option<String>,
83 abort_handle: AbortHandle,
84}
85
86#[cfg(not(feature = "runtime-tokio"))]
87impl<M> DefaultTaskBackend<M>
88where
89 M: Send + 'static,
90{
91 fn new(dispatcher: Dispatcher<M>) -> Self {
92 Self {
93 dispatcher,
94 tasks: HashMap::new(),
95 shutting_down: false,
96 }
97 }
98
99 fn close(&mut self) {
100 if self.shutting_down {
101 return;
102 }
103 self.shutting_down = true;
104 self.cancel_all();
105 }
106}
107
108#[cfg(not(feature = "runtime-tokio"))]
109impl<M> TaskBackend<M> for DefaultTaskBackend<M>
110where
111 M: Send + 'static,
112{
113 fn spawn_boxed(&mut self, name: Option<String>, future: BoxedFuture<M>) -> TaskId {
114 let task_id = TaskId::new();
115 if self.shutting_down {
116 drop(future);
117 self.dispatcher.finish_task(task_id);
118 return task_id;
119 }
120
121 let (abort_handle, abort_registration) = AbortHandle::new_pair();
122 let dispatcher = self.dispatcher.clone();
123 thread::spawn(move || {
124 let future = Abortable::new(future, abort_registration);
125 if let Ok(message) = futures::executor::block_on(future) {
126 let _ = dispatcher.post_message(message);
127 }
128 dispatcher.finish_task(task_id);
129 });
130 self.tasks.insert(task_id, DefaultTaskControl { name, abort_handle });
131 task_id
132 }
133
134 fn cancel(&mut self, task_id: TaskId) -> bool {
135 if let Some(control) = self.tasks.remove(&task_id) {
136 control.abort_handle.abort();
137 return true;
138 }
139 false
140 }
141
142 fn cancel_all(&mut self) {
143 for (_, control) in self.tasks.drain() {
144 control.abort_handle.abort();
145 }
146 }
147
148 fn reap(&mut self, task_id: TaskId) {
149 let _ = self.tasks.remove(&task_id);
150 }
151
152 fn shutdown(&mut self) {
153 self.close();
154 }
155
156 #[cfg(test)]
157 fn active_count(&self) -> usize {
158 self.tasks.len()
159 }
160}
161
162#[cfg(not(feature = "runtime-tokio"))]
163impl<M> Drop for DefaultTaskBackend<M> {
164 fn drop(&mut self) {
165 self.shutting_down = true;
166 for (_, control) in self.tasks.drain() {
167 control.abort_handle.abort();
168 }
169 }
170}
171
/// Backend that runs tasks on a dedicated multi-threaded tokio runtime owned
/// by this value.
#[cfg(feature = "runtime-tokio")]
struct TokioTaskBackend<M> {
    /// Used to post task output and to report task completion.
    dispatcher: Dispatcher<M>,
    /// Runtime the tasks are spawned onto; dropped together with the backend.
    runtime: tokio::runtime::Runtime,
    /// Tasks that were spawned and have not yet been cancelled or reaped.
    tasks: HashMap<TaskId, TokioTaskControl>,
    /// Set by `close`/`drop`; afterwards new futures are discarded unrun.
    shutting_down: bool,
}
179
/// Per-task bookkeeping kept by `TokioTaskBackend`.
#[cfg(feature = "runtime-tokio")]
struct TokioTaskControl {
    // Never read in this module; retained alongside the task, presumably for
    // diagnostics — TODO confirm before removing.
    #[allow(dead_code)]
    name: Option<String>,
    /// Aborts the task's user future. Cancelling through an `AbortHandle`
    /// rather than `JoinHandle::abort` lets the spawned wrapper observe the
    /// abort and still call `finish_task`, matching `DefaultTaskBackend`.
    abort_handle: futures::future::AbortHandle,
}

#[cfg(feature = "runtime-tokio")]
impl<M> TokioTaskBackend<M>
where
    M: Send + 'static,
{
    /// Creates a backend with its own multi-threaded tokio runtime (all
    /// drivers enabled).
    ///
    /// # Panics
    ///
    /// Panics if the tokio runtime cannot be built.
    fn new(dispatcher: Dispatcher<M>) -> Self {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .expect("tokio runtime backend must initialize");
        Self {
            dispatcher,
            runtime,
            tasks: HashMap::new(),
            shutting_down: false,
        }
    }

    /// Enters shutdown mode and cancels every tracked task; later calls are
    /// no-ops.
    fn close(&mut self) {
        if self.shutting_down {
            return;
        }
        self.shutting_down = true;
        self.cancel_all();
    }
}

#[cfg(feature = "runtime-tokio")]
impl<M> TaskBackend<M> for TokioTaskBackend<M>
where
    M: Send + 'static,
{
    /// Spawns `future` onto the runtime and tracks it under a fresh id.
    ///
    /// The spawned wrapper posts the future's output unless the task was
    /// aborted, then always reports `finish_task`. Cancelled tasks are
    /// therefore reported exactly like in `DefaultTaskBackend`. After shutdown
    /// the future is dropped unrun and `finish_task` is reported immediately.
    fn spawn_boxed(&mut self, name: Option<String>, future: BoxedFuture<M>) -> TaskId {
        let task_id = TaskId::new();
        if self.shutting_down {
            drop(future);
            self.dispatcher.finish_task(task_id);
            return task_id;
        }

        let (abort_handle, abort_registration) = futures::future::AbortHandle::new_pair();
        let dispatcher = self.dispatcher.clone();
        let join_handle = self.runtime.spawn(async move {
            let future = futures::future::Abortable::new(future, abort_registration);
            // `Err(Aborted)` means `cancel` fired: nothing to post, but the
            // completion report below must still happen.
            if let Ok(message) = future.await {
                let _ = dispatcher.post_message(message);
            }
            dispatcher.finish_task(task_id);
        });
        // Detach: the task keeps running without its JoinHandle, and
        // cancellation goes through `abort_handle` instead.
        drop(join_handle);
        self.tasks.insert(task_id, TokioTaskControl { name, abort_handle });
        task_id
    }

    /// Aborts and forgets `task_id`; returns `false` if it was not tracked.
    fn cancel(&mut self, task_id: TaskId) -> bool {
        if let Some(control) = self.tasks.remove(&task_id) {
            control.abort_handle.abort();
            return true;
        }
        false
    }

    /// Aborts and forgets every tracked task.
    fn cancel_all(&mut self) {
        for (_, control) in self.tasks.drain() {
            control.abort_handle.abort();
        }
    }

    /// Forgets `task_id` without aborting it.
    fn reap(&mut self, task_id: TaskId) {
        let _ = self.tasks.remove(&task_id);
    }

    /// Delegates to `close`, which is idempotent.
    fn shutdown(&mut self) {
        self.close();
    }

    #[cfg(test)]
    fn active_count(&self) -> usize {
        self.tasks.len()
    }
}

#[cfg(feature = "runtime-tokio")]
impl<M> Drop for TokioTaskBackend<M> {
    /// Marks the backend as shutting down and aborts every tracked task.
    ///
    /// `close` cannot be called here: it lives in an impl bounded by
    /// `M: Send + 'static`, which a `Drop` impl may not add. The `runtime`
    /// field is dropped right after this, so aborted tasks may be discarded by
    /// the runtime's shutdown before they get to report `finish_task`.
    fn drop(&mut self) {
        self.shutting_down = true;
        for (_, control) in self.tasks.drain() {
            control.abort_handle.abort();
        }
    }
}