1use crate::{config::TaskType, Error};
10use exit_future::Signal;
11use futures::{
12 future::{pending, select, try_join_all, BoxFuture, Either},
13 Future, FutureExt, StreamExt,
14};
15use parking_lot::Mutex;
16use soil_prometheus::{
17 exponential_buckets, register, CounterVec, HistogramOpts, HistogramVec, Opts, PrometheusError,
18 Registry, U64,
19};
20use soil_client::utils::mpsc::{
21 tracing_unbounded, TracingUnboundedReceiver, TracingUnboundedSender,
22};
23use std::{
24 collections::{hash_map::Entry, HashMap},
25 panic,
26 pin::Pin,
27 result::Result,
28 sync::Arc,
29};
30use tokio::runtime::Handle;
31use tracing_futures::Instrument;
32
33mod prometheus_future;
34#[cfg(test)]
35mod tests;
36
/// Name of the group a task is placed in when the caller does not specify one.
pub const DEFAULT_GROUP_NAME: &str = "default";

/// The group a spawned task belongs to.
///
/// The resolved group name is used as the `task_group` metric label and as part of the
/// task's entry in the [`TaskRegistry`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GroupName {
    /// Resolves to [`DEFAULT_GROUP_NAME`].
    Default,
    /// An explicitly named group.
    Specific(&'static str),
}
50
51impl From<Option<&'static str>> for GroupName {
52 fn from(name: Option<&'static str>) -> Self {
53 match name {
54 Some(name) => Self::Specific(name),
55 None => Self::Default,
56 }
57 }
58}
59
60impl From<&'static str> for GroupName {
61 fn from(name: &'static str) -> Self {
62 Self::Specific(name)
63 }
64}
65
66#[derive(Clone)]
68pub struct SpawnTaskHandle {
69 on_exit: exit_future::Exit,
70 tokio_handle: Handle,
71 metrics: Option<Metrics>,
72 task_registry: TaskRegistry,
73}
74
75impl SpawnTaskHandle {
76 pub fn spawn(
86 &self,
87 name: &'static str,
88 group: impl Into<GroupName>,
89 task: impl Future<Output = ()> + Send + 'static,
90 ) {
91 self.spawn_inner(name, group, task, TaskType::Async)
92 }
93
94 pub fn spawn_blocking(
96 &self,
97 name: &'static str,
98 group: impl Into<GroupName>,
99 task: impl Future<Output = ()> + Send + 'static,
100 ) {
101 self.spawn_inner(name, group, task, TaskType::Blocking)
102 }
103
104 fn spawn_inner(
106 &self,
107 name: &'static str,
108 group: impl Into<GroupName>,
109 task: impl Future<Output = ()> + Send + 'static,
110 task_type: TaskType,
111 ) {
112 let on_exit = self.on_exit.clone();
113 let metrics = self.metrics.clone();
114 let registry = self.task_registry.clone();
115
116 let group = match group.into() {
117 GroupName::Specific(var) => var,
118 GroupName::Default => DEFAULT_GROUP_NAME,
120 };
121
122 let task_type_label = match task_type {
123 TaskType::Blocking => "blocking",
124 TaskType::Async => "async",
125 };
126
127 if let Some(metrics) = &self.metrics {
130 metrics.tasks_spawned.with_label_values(&[name, group, task_type_label]).inc();
131 metrics
133 .tasks_ended
134 .with_label_values(&[name, "finished", group, task_type_label])
135 .inc_by(0);
136 }
137
138 let future = async move {
139 let _registry_token = registry.register_task(name, group);
142
143 if let Some(metrics) = metrics {
144 let task = {
146 let poll_duration =
147 metrics.poll_duration.with_label_values(&[name, group, task_type_label]);
148 let poll_start =
149 metrics.poll_start.with_label_values(&[name, group, task_type_label]);
150 let inner =
151 prometheus_future::with_poll_durations(poll_duration, poll_start, task);
152 panic::AssertUnwindSafe(inner).catch_unwind()
155 };
156 futures::pin_mut!(task);
157
158 match select(on_exit, task).await {
159 Either::Right((Err(payload), _)) => {
160 metrics
161 .tasks_ended
162 .with_label_values(&[name, "panic", group, task_type_label])
163 .inc();
164 panic::resume_unwind(payload)
165 },
166 Either::Right((Ok(()), _)) => {
167 metrics
168 .tasks_ended
169 .with_label_values(&[name, "finished", group, task_type_label])
170 .inc();
171 },
172 Either::Left(((), _)) => {
173 metrics
175 .tasks_ended
176 .with_label_values(&[name, "interrupted", group, task_type_label])
177 .inc();
178 },
179 }
180 } else {
181 futures::pin_mut!(task);
182 let _ = select(on_exit, task).await;
183 }
184 }
185 .in_current_span();
186
187 match task_type {
188 TaskType::Async => {
189 self.tokio_handle.spawn(future);
190 },
191 TaskType::Blocking => {
192 let handle = self.tokio_handle.clone();
193 self.tokio_handle.spawn_blocking(move || {
194 handle.block_on(future);
195 });
196 },
197 }
198 }
199}
200
201impl subsoil::core::traits::SpawnNamed for SpawnTaskHandle {
202 fn spawn_blocking(
203 &self,
204 name: &'static str,
205 group: Option<&'static str>,
206 future: BoxFuture<'static, ()>,
207 ) {
208 self.spawn_inner(name, group, future, TaskType::Blocking)
209 }
210
211 fn spawn(
212 &self,
213 name: &'static str,
214 group: Option<&'static str>,
215 future: BoxFuture<'static, ()>,
216 ) {
217 self.spawn_inner(name, group, future, TaskType::Async)
218 }
219}
220
221#[derive(Clone)]
226pub struct SpawnEssentialTaskHandle {
227 essential_failed_tx: TracingUnboundedSender<()>,
228 inner: SpawnTaskHandle,
229}
230
231impl SpawnEssentialTaskHandle {
232 pub fn new(
234 essential_failed_tx: TracingUnboundedSender<()>,
235 spawn_task_handle: SpawnTaskHandle,
236 ) -> SpawnEssentialTaskHandle {
237 SpawnEssentialTaskHandle { essential_failed_tx, inner: spawn_task_handle }
238 }
239
240 pub fn spawn(
244 &self,
245 name: &'static str,
246 group: impl Into<GroupName>,
247 task: impl Future<Output = ()> + Send + 'static,
248 ) {
249 self.spawn_inner(name, group, task, TaskType::Async)
250 }
251
252 pub fn spawn_blocking(
256 &self,
257 name: &'static str,
258 group: impl Into<GroupName>,
259 task: impl Future<Output = ()> + Send + 'static,
260 ) {
261 self.spawn_inner(name, group, task, TaskType::Blocking)
262 }
263
264 fn spawn_inner(
265 &self,
266 name: &'static str,
267 group: impl Into<GroupName>,
268 task: impl Future<Output = ()> + Send + 'static,
269 task_type: TaskType,
270 ) {
271 let essential_failed = self.essential_failed_tx.clone();
272 let essential_task = std::panic::AssertUnwindSafe(task).catch_unwind().map(move |_| {
273 log::error!("Essential task `{}` failed. Shutting down service.", name);
274 let _ = essential_failed.close();
275 });
276
277 let _ = self.inner.spawn_inner(name, group, essential_task, task_type);
278 }
279}
280
281impl subsoil::core::traits::SpawnEssentialNamed for SpawnEssentialTaskHandle {
282 fn spawn_essential_blocking(
283 &self,
284 name: &'static str,
285 group: Option<&'static str>,
286 future: BoxFuture<'static, ()>,
287 ) {
288 self.spawn_blocking(name, group, future);
289 }
290
291 fn spawn_essential(
292 &self,
293 name: &'static str,
294 group: Option<&'static str>,
295 future: BoxFuture<'static, ()>,
296 ) {
297 self.spawn(name, group, future);
298 }
299}
300
301pub struct TaskManager {
303 on_exit: exit_future::Exit,
306 _signal: Signal,
308 tokio_handle: Handle,
310 metrics: Option<Metrics>,
312 essential_failed_tx: TracingUnboundedSender<()>,
315 essential_failed_rx: TracingUnboundedReceiver<()>,
317 keep_alive: Box<dyn std::any::Any + Send>,
319 children: Vec<TaskManager>,
323 task_registry: TaskRegistry,
325}
326
327impl TaskManager {
328 pub fn new(
331 tokio_handle: Handle,
332 prometheus_registry: Option<&Registry>,
333 ) -> Result<Self, PrometheusError> {
334 let (signal, on_exit) = exit_future::signal();
335
336 let (essential_failed_tx, essential_failed_rx) =
338 tracing_unbounded("mpsc_essential_tasks", 100);
339
340 let metrics = prometheus_registry.map(Metrics::register).transpose()?;
341
342 Ok(Self {
343 on_exit,
344 _signal: signal,
345 tokio_handle,
346 metrics,
347 essential_failed_tx,
348 essential_failed_rx,
349 keep_alive: Box::new(()),
350 children: Vec::new(),
351 task_registry: Default::default(),
352 })
353 }
354
355 pub fn spawn_handle(&self) -> SpawnTaskHandle {
357 SpawnTaskHandle {
358 on_exit: self.on_exit.clone(),
359 tokio_handle: self.tokio_handle.clone(),
360 metrics: self.metrics.clone(),
361 task_registry: self.task_registry.clone(),
362 }
363 }
364
365 pub fn spawn_essential_handle(&self) -> SpawnEssentialTaskHandle {
367 SpawnEssentialTaskHandle::new(self.essential_failed_tx.clone(), self.spawn_handle())
368 }
369
370 pub fn future<'a>(
377 &'a mut self,
378 ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>> {
379 Box::pin(async move {
380 let mut t1 = self.essential_failed_rx.next().fuse();
381 let mut t2 = self.on_exit.clone().fuse();
382 let mut t3 = try_join_all(
383 self.children
384 .iter_mut()
385 .map(|x| x.future())
386 .chain(std::iter::once(pending().boxed())),
389 )
390 .fuse();
391
392 futures::select! {
393 _ = t1 => Err(Error::Other("Essential task failed.".into())),
394 _ = t2 => Ok(()),
395 res = t3 => Err(res.map(|_| ()).expect_err("this future never ends; qed")),
396 }
397 })
398 }
399
400 pub fn keep_alive<T: 'static + Send>(&mut self, to_keep_alive: T) {
402 use std::mem;
404 let old = mem::replace(&mut self.keep_alive, Box::new(()));
405 self.keep_alive = Box::new((to_keep_alive, old));
406 }
407
408 pub fn add_child(&mut self, child: TaskManager) {
412 self.children.push(child);
413 }
414
415 pub fn into_task_registry(self) -> TaskRegistry {
420 self.task_registry
421 }
422}
423
424#[derive(Clone)]
425struct Metrics {
426 poll_duration: HistogramVec,
428 poll_start: CounterVec<U64>,
429 tasks_spawned: CounterVec<U64>,
430 tasks_ended: CounterVec<U64>,
431}
432
433impl Metrics {
434 fn register(registry: &Registry) -> Result<Self, PrometheusError> {
435 Ok(Self {
436 poll_duration: register(HistogramVec::new(
437 HistogramOpts {
438 common_opts: Opts::new(
439 "substrate_tasks_polling_duration",
440 "Duration in seconds of each invocation of Future::poll"
441 ),
442 buckets: exponential_buckets(0.001, 4.0, 9)
443 .expect("function parameters are constant and always valid; qed"),
444 },
445 &["task_name", "task_group", "kind"]
446 )?, registry)?,
447 poll_start: register(CounterVec::new(
448 Opts::new(
449 "substrate_tasks_polling_started_total",
450 "Total number of times we started invoking Future::poll"
451 ),
452 &["task_name", "task_group", "kind"]
453 )?, registry)?,
454 tasks_spawned: register(CounterVec::new(
455 Opts::new(
456 "substrate_tasks_spawned_total",
457 "Total number of tasks that have been spawned on the Service"
458 ),
459 &["task_name", "task_group", "kind"]
460 )?, registry)?,
461 tasks_ended: register(CounterVec::new(
462 Opts::new(
463 "substrate_tasks_ended_total",
464 "Total number of tasks for which Future::poll has returned Ready(()) or panicked"
465 ),
466 &["task_name", "reason", "task_group", "kind"]
467 )?, registry)?,
468 })
469 }
470}
471
472struct UnregisterOnDrop {
474 task: Task,
475 registry: TaskRegistry,
476}
477
478impl Drop for UnregisterOnDrop {
479 fn drop(&mut self) {
480 let mut tasks = self.registry.tasks.lock();
481
482 if let Entry::Occupied(mut entry) = (*tasks).entry(self.task.clone()) {
483 *entry.get_mut() -= 1;
484
485 if *entry.get() == 0 {
486 entry.remove();
487 }
488 }
489 }
490}
491
492#[derive(Clone, Hash, Eq, PartialEq)]
497pub struct Task {
498 pub name: &'static str,
500 pub group: &'static str,
502}
503
504impl Task {
505 pub fn is_default_group(&self) -> bool {
507 self.group == DEFAULT_GROUP_NAME
508 }
509}
510
511#[derive(Clone, Default)]
513pub struct TaskRegistry {
514 tasks: Arc<Mutex<HashMap<Task, usize>>>,
515}
516
517impl TaskRegistry {
518 fn register_task(&self, name: &'static str, group: &'static str) -> UnregisterOnDrop {
523 let task = Task { name, group };
524
525 {
526 let mut tasks = self.tasks.lock();
527
528 *(*tasks).entry(task.clone()).or_default() += 1;
529 }
530
531 UnregisterOnDrop { task, registry: self.clone() }
532 }
533
534 pub fn running_tasks(&self) -> HashMap<Task, usize> {
539 (*self.tasks.lock()).clone()
540 }
541}