Skip to main content

widgetkit_runtime/
tasks.rs

1use crate::internal::Dispatcher;
2use futures::Future;
3#[cfg(not(feature = "runtime-tokio"))]
4use futures::future::{AbortHandle, Abortable};
5#[cfg(not(feature = "runtime-tokio"))]
6use std::thread;
7use std::{collections::HashMap, pin::Pin};
8use widgetkit_core::TaskId;
9
10pub struct Tasks<'a, M> {
11    backend: &'a mut dyn TaskBackend<M>,
12}
13
14impl<'a, M> Tasks<'a, M>
15where
16    M: Send + 'static,
17{
18    pub(crate) fn new(backend: &'a mut dyn TaskBackend<M>) -> Self {
19        Self { backend }
20    }
21
22    pub fn spawn<F>(&mut self, future: F) -> TaskId
23    where
24        F: Future<Output = M> + Send + 'static,
25    {
26        self.backend.spawn_boxed(None, Box::pin(future))
27    }
28
29    pub fn spawn_named<F>(&mut self, name: impl Into<String>, future: F) -> TaskId
30    where
31        F: Future<Output = M> + Send + 'static,
32    {
33        self.backend
34            .spawn_boxed(Some(name.into()), Box::pin(future))
35    }
36
37    pub fn cancel(&mut self, task_id: TaskId) -> bool {
38        self.backend.cancel(task_id)
39    }
40
41    pub fn cancel_all(&mut self) {
42        self.backend.cancel_all();
43    }
44}
45
/// A type-erased, pinned, heap-allocated future that yields a message `M`.
///
/// The `Send + 'static` bounds let it be moved onto another thread or runtime task.
pub(crate) type BoxedFuture<M> = Pin<Box<dyn Future<Output = M> + Send + 'static>>;

/// Execution strategy behind `Tasks`.
///
/// This module provides `DefaultTaskBackend` (one OS thread per task) and
/// `TokioTaskBackend` (feature `runtime-tokio`); `task_backend` picks one.
pub(crate) trait TaskBackend<M>: Send {
    /// Starts `future` and returns the id it is tracked under.
    fn spawn_boxed(&mut self, name: Option<String>, future: BoxedFuture<M>) -> TaskId;
    /// Aborts the task and stops tracking it; returns `true` if it was tracked.
    fn cancel(&mut self, task_id: TaskId) -> bool;
    /// Aborts and stops tracking every task.
    fn cancel_all(&mut self);
    /// Removes the bookkeeping entry for `task_id` without aborting it.
    // NOTE(review): presumably called once the dispatcher sees the task finish — confirm.
    fn reap(&mut self, task_id: TaskId);
    /// Aborts all tasks; in both backends here, later spawns are dropped unpolled.
    fn shutdown(&mut self);
    /// Number of tasks currently tracked (test builds only).
    #[cfg(test)]
    fn active_count(&self) -> usize;
}
57
/// Builds the task backend selected at compile time.
///
/// Returns `TokioTaskBackend` when the `runtime-tokio` feature is enabled and
/// `DefaultTaskBackend` otherwise.
pub(crate) fn task_backend<M>(dispatcher: Dispatcher<M>) -> Box<dyn TaskBackend<M>>
where
    M: Send + 'static,
{
    // Exactly one of the two cfg blocks below survives compilation.
    #[cfg(feature = "runtime-tokio")]
    {
        // NOTE(review): the explicit `return` appears to be what keeps this
        // branch valid once the other block is cfg'd out (a non-tail block
        // statement would otherwise have to evaluate to `()`). Check both
        // feature configurations before turning it into a tail expression.
        return Box::new(TokioTaskBackend::new(dispatcher));
    }

    #[cfg(not(feature = "runtime-tokio"))]
    {
        Box::new(DefaultTaskBackend::new(dispatcher))
    }
}
72
/// Backend that runs each task on its own OS thread, driving the future with
/// `futures::executor::block_on`.
#[cfg(not(feature = "runtime-tokio"))]
struct DefaultTaskBackend<M> {
    /// Receives each task's output message and its completion report.
    dispatcher: Dispatcher<M>,
    /// Abort handles for tasks that have not been cancelled or reaped yet.
    tasks: HashMap<TaskId, DefaultTaskControl>,
    /// Set by shutdown/drop; spawns after that are dropped without running.
    shutting_down: bool,
}
79
/// Per-task bookkeeping kept by `DefaultTaskBackend`.
#[cfg(not(feature = "runtime-tokio"))]
struct DefaultTaskControl {
    // The name is stored but never read in this module, hence the allow.
    #[allow(dead_code)]
    name: Option<String>,
    /// Aborts the task's `Abortable`-wrapped future when triggered.
    abort_handle: AbortHandle,
}
86
87#[cfg(not(feature = "runtime-tokio"))]
88impl<M> DefaultTaskBackend<M>
89where
90    M: Send + 'static,
91{
92    fn new(dispatcher: Dispatcher<M>) -> Self {
93        Self {
94            dispatcher,
95            tasks: HashMap::new(),
96            shutting_down: false,
97        }
98    }
99
100    fn close(&mut self) {
101        if self.shutting_down {
102            return;
103        }
104        self.shutting_down = true;
105        self.cancel_all();
106    }
107}
108
109#[cfg(not(feature = "runtime-tokio"))]
110impl<M> TaskBackend<M> for DefaultTaskBackend<M>
111where
112    M: Send + 'static,
113{
114    fn spawn_boxed(&mut self, name: Option<String>, future: BoxedFuture<M>) -> TaskId {
115        let task_id = TaskId::new();
116        if self.shutting_down {
117            drop(future);
118            self.dispatcher.finish_task(task_id);
119            return task_id;
120        }
121
122        let (abort_handle, abort_registration) = AbortHandle::new_pair();
123        let dispatcher = self.dispatcher.clone();
124        thread::spawn(move || {
125            let future = Abortable::new(future, abort_registration);
126            if let Ok(message) = futures::executor::block_on(future) {
127                let _ = dispatcher.post_message(message);
128            }
129            dispatcher.finish_task(task_id);
130        });
131        self.tasks
132            .insert(task_id, DefaultTaskControl { name, abort_handle });
133        task_id
134    }
135
136    fn cancel(&mut self, task_id: TaskId) -> bool {
137        if let Some(control) = self.tasks.remove(&task_id) {
138            control.abort_handle.abort();
139            return true;
140        }
141        false
142    }
143
144    fn cancel_all(&mut self) {
145        for (_, control) in self.tasks.drain() {
146            control.abort_handle.abort();
147        }
148    }
149
150    fn reap(&mut self, task_id: TaskId) {
151        let _ = self.tasks.remove(&task_id);
152    }
153
154    fn shutdown(&mut self) {
155        self.close();
156    }
157
158    #[cfg(test)]
159    fn active_count(&self) -> usize {
160        self.tasks.len()
161    }
162}
163
164#[cfg(not(feature = "runtime-tokio"))]
165impl<M> Drop for DefaultTaskBackend<M> {
166    fn drop(&mut self) {
167        self.shutting_down = true;
168        for (_, control) in self.tasks.drain() {
169            control.abort_handle.abort();
170        }
171    }
172}
173
/// Backend that spawns tasks onto a multi-threaded tokio runtime it owns.
#[cfg(feature = "runtime-tokio")]
struct TokioTaskBackend<M> {
    /// Receives each task's output message and its completion report.
    dispatcher: Dispatcher<M>,
    /// Runtime built in `new` and owned by this backend.
    runtime: tokio::runtime::Runtime,
    /// Join handles for tasks that have not been cancelled or reaped yet.
    tasks: HashMap<TaskId, TokioTaskControl>,
    /// Set by shutdown/drop; spawns after that are dropped without running.
    shutting_down: bool,
}
181
/// Per-task bookkeeping kept by `TokioTaskBackend`.
#[cfg(feature = "runtime-tokio")]
struct TokioTaskControl {
    // The name is stored but never read in this module, hence the allow.
    #[allow(dead_code)]
    name: Option<String>,
    /// Handle used to abort the spawned tokio task.
    join_handle: tokio::task::JoinHandle<()>,
}
188
#[cfg(feature = "runtime-tokio")]
impl<M> TokioTaskBackend<M>
where
    M: Send + 'static,
{
    /// Builds a multi-threaded runtime with all drivers enabled.
    ///
    /// # Panics
    /// Panics if the tokio runtime cannot be built.
    fn new(dispatcher: Dispatcher<M>) -> Self {
        let mut builder = tokio::runtime::Builder::new_multi_thread();
        builder.enable_all();
        let runtime = builder
            .build()
            .expect("tokio runtime backend must initialize");
        Self {
            dispatcher,
            runtime,
            tasks: HashMap::new(),
            shutting_down: false,
        }
    }

    /// Enters shutdown mode and aborts every tracked task (idempotent).
    fn close(&mut self) {
        if !self.shutting_down {
            self.shutting_down = true;
            self.cancel_all();
        }
    }
}
215
#[cfg(feature = "runtime-tokio")]
impl<M> TaskBackend<M> for TokioTaskBackend<M>
where
    M: Send + 'static,
{
    /// Spawns `future` onto the backend's tokio runtime.
    ///
    /// The future's output is posted to the dispatcher. `finish_task` is
    /// reported exactly once for the returned id: after the message is posted,
    /// or when the task's future is dropped early (aborted via `cancel`,
    /// panicked, or otherwise dropped by the runtime). Previously an aborted
    /// task was dropped before reaching `finish_task`, unlike the default
    /// backend, which reports cancelled tasks as finished.
    ///
    /// Once the backend is shutting down, the future is dropped unpolled and
    /// the id is reported as finished immediately.
    fn spawn_boxed(&mut self, name: Option<String>, future: BoxedFuture<M>) -> TaskId {
        // Reports `finish_task` when dropped. It is moved into the spawned
        // future, so aborting the tokio task still notifies the dispatcher.
        struct FinishOnDrop<M: Send + 'static> {
            dispatcher: Dispatcher<M>,
            task_id: TaskId,
        }

        impl<M: Send + 'static> Drop for FinishOnDrop<M> {
            fn drop(&mut self) {
                self.dispatcher.finish_task(self.task_id);
            }
        }

        let task_id = TaskId::new();
        if self.shutting_down {
            drop(future);
            self.dispatcher.finish_task(task_id);
            return task_id;
        }

        let finish = FinishOnDrop {
            dispatcher: self.dispatcher.clone(),
            task_id,
        };
        let join_handle = self.runtime.spawn(async move {
            let message = future.await;
            let _ = finish.dispatcher.post_message(message);
            // Report completion only after the message has been posted.
            drop(finish);
        });
        self.tasks
            .insert(task_id, TokioTaskControl { name, join_handle });
        task_id
    }

    /// Aborts `task_id` and stops tracking it.
    ///
    /// Returns `true` if the task was still tracked.
    fn cancel(&mut self, task_id: TaskId) -> bool {
        if let Some(control) = self.tasks.remove(&task_id) {
            control.join_handle.abort();
            return true;
        }
        false
    }

    /// Aborts and stops tracking every task.
    fn cancel_all(&mut self) {
        for (_, control) in self.tasks.drain() {
            control.join_handle.abort();
        }
    }

    /// Removes the bookkeeping entry for `task_id` (no-op if absent).
    fn reap(&mut self, task_id: TaskId) {
        let _ = self.tasks.remove(&task_id);
    }

    /// Enters shutdown mode: aborts all tasks and rejects later spawns.
    fn shutdown(&mut self) {
        self.close();
    }

    #[cfg(test)]
    fn active_count(&self) -> usize {
        self.tasks.len()
    }
}
267
#[cfg(feature = "runtime-tokio")]
impl<M> Drop for TokioTaskBackend<M> {
    /// Marks the backend as shutting down and aborts every tracked task.
    ///
    /// The abort loop is repeated inline because `cancel_all` lives in an impl
    /// that requires `M: Send + 'static`, a bound a `Drop` impl cannot add.
    fn drop(&mut self) {
        self.shutting_down = true;
        let drained = self.tasks.drain().map(|(_, control)| control);
        for control in drained {
            control.join_handle.abort();
        }
    }
}