wire_framework/service/
mod.rs

1use std::{collections::HashMap, time::Duration};
2
3use crate::utils::try_extract_panic_message;
4use futures::future::Fuse;
5use tokio::{runtime::Runtime, sync::watch, task::JoinHandle};
6
7pub use self::{
8    context::ServiceContext,
9    context_traits::{FromContext, IntoContext},
10    error::{ServiceError, TaskError},
11    shutdown_hook::ShutdownHook,
12    stop_receiver::StopReceiver,
13};
14use crate::{
15    resource::{ResourceId, StoredResource},
16    service::{
17        named_future::NamedFuture,
18        runnables::{NamedBoxFuture, Runnables, TaskReprs},
19    },
20    task::TaskId,
21    wiring_layer::{WireFn, WiringError, WiringLayer, WiringLayerExt},
22};
23
24mod context;
25mod context_traits;
26mod error;
27mod named_future;
28mod runnables;
29mod shutdown_hook;
30mod stop_receiver;
31#[cfg(test)]
32mod tests;
33
/// A reasonable amount of time for any task to finish the shutdown process.
///
/// Applied individually to every remaining task after the stop signal is sent
/// (see `Service::shutdown_tasks`) and to every shutdown hook
/// (see `Service::run_shutdown_hooks`).
const TASK_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30);
36
/// A builder for [`Service`].
///
/// Collects wiring layers and owns the Tokio runtime; both are handed over to the
/// [`Service`] by [`ServiceBuilder::build`].
#[derive(Debug)]
pub struct ServiceBuilder {
    /// List of wiring layers, stored as `(layer name, wire function)` pairs
    /// in the order they were added.
    // Note: It has to be a `Vec` and not e.g. `HashMap` because the order in which we
    // iterate through it matters.
    layers: Vec<(&'static str, WireFn)>,
    /// Tokio runtime used to spawn tasks.
    runtime: Runtime,
}
47
48impl ServiceBuilder {
49    /// Creates a new builder.
50    ///
51    /// Returns an error if called within a Tokio runtime context.
52    pub fn new() -> Result<Self, ServiceError> {
53        if tokio::runtime::Handle::try_current().is_ok() {
54            return Err(ServiceError::RuntimeDetected);
55        }
56        let runtime = tokio::runtime::Builder::new_multi_thread()
57            .enable_all()
58            .build()
59            .unwrap();
60        Ok(Self::on_runtime(runtime))
61    }
62
63    /// Creates a new builder with the provided Tokio runtime.
64    /// This method can be used if asynchronous tasks must be performed before the service is built.
65    ///
66    /// However, it is not recommended to use this method to spawn any tasks that will not be managed
67    /// by the service itself, so whenever it can be avoided, using [`ServiceBuilder::new`] is preferred.
68    pub fn on_runtime(runtime: Runtime) -> Self {
69        Self {
70            layers: Vec::new(),
71            runtime,
72        }
73    }
74
75    /// Returns a handle to the Tokio runtime used by the service.
76    pub fn runtime_handle(&self) -> tokio::runtime::Handle {
77        self.runtime.handle().clone()
78    }
79
80    /// Adds a wiring layer.
81    ///
82    /// During the [`run`](Service::run) call the service will invoke
83    /// `wire` method of every layer in the order they were added.
84    ///
85    /// This method may be invoked multiple times with the same layer type, but the
86    /// layer will only be stored once (meaning that 2nd attempt to add the same layer will be ignored).
87    /// This may be useful if the same layer is a prerequisite for multiple other layers: it is safe
88    /// to add it multiple times, and it will only be wired once.
89    pub fn add_layer<T: WiringLayer>(&mut self, layer: T) -> &mut Self {
90        let name = layer.layer_name();
91        if !self
92            .layers
93            .iter()
94            .any(|(existing_name, _)| name == *existing_name)
95        {
96            self.layers.push((name, layer.into_wire_fn()));
97        }
98        self
99    }
100
101    /// Builds the service.
102    pub fn build(self) -> Service {
103        let (stop_sender, _stop_receiver) = watch::channel(false);
104
105        Service {
106            layers: self.layers,
107            resources: Default::default(),
108            runnables: Default::default(),
109            stop_sender,
110            runtime: self.runtime,
111            errors: Vec::new(),
112        }
113    }
114}
115
/// "Manager" class for a set of tasks. Collects all the resources and tasks,
/// then runs tasks until completion.
///
/// Created by [`ServiceBuilder::build`] and consumed by [`Service::run`] /
/// [`Service::run_with_guard`].
#[derive(Debug)]
pub struct Service {
    /// Cache of resources that have been requested by at least one task.
    /// Emptied once wiring is complete.
    resources: HashMap<ResourceId, Box<dyn StoredResource>>,
    /// List of wiring layers as `(layer name, wire function)` pairs, in the order
    /// they were added. Taken out (left empty) when wiring starts.
    layers: Vec<(&'static str, WireFn)>,
    /// Different kinds of tasks for the service.
    runnables: Runnables,

    /// Sender used to stop the tasks: `true` is sent once the first task has exited.
    stop_sender: watch::Sender<bool>,
    /// Tokio runtime used to spawn tasks.
    runtime: Runtime,

    /// Collector for the task errors met during the service execution.
    errors: Vec<TaskError>,
}
135
/// A spawned task: the fused [`JoinHandle`] of the task wrapped in a [`NamedFuture`].
/// This is the type of the "remaining" futures returned by `Service::run_tasks`.
type TaskFuture = NamedFuture<Fuse<JoinHandle<eyre::Result<()>>>>;
137
138impl Service {
139    /// Runs the system.
140    ///
141    /// In case of errors during wiring phase, will return the list of all the errors that happened, in the order
142    /// of their occurrence.
143    pub fn run(self) -> Result<(), ServiceError> {
144        self.run_with_guard(())
145    }
146
147    /// Runs the system.
148    ///
149    /// In case of errors during wiring phase, will return the list of all the errors that happened, in the order
150    /// of their occurrence.
151    ///
152    /// `observability_guard` will be used to deinitialize the observability subsystem
153    /// as the very last step before exiting the node.
154    pub fn run_with_guard<G>(mut self, observability_guard: G) -> Result<(), ServiceError> {
155        self.wire()?;
156
157        let TaskReprs {
158            tasks,
159            shutdown_hooks,
160        } = self.prepare_tasks();
161
162        let remaining = self.run_tasks(tasks);
163        self.shutdown_tasks(remaining);
164        self.run_shutdown_hooks(shutdown_hooks);
165
166        tracing::info!("Exiting the service");
167
168        if std::mem::needs_drop::<G>() {
169            // Make sure that the shutdown happens in the `tokio` context.
170            let _guard = self.runtime.enter();
171            drop(observability_guard);
172        }
173
174        if self.errors.is_empty() {
175            Ok(())
176        } else {
177            Err(ServiceError::Task(self.errors.into()))
178        }
179    }
180
181    /// Performs wiring of the service.
182    /// After invoking this method, the collected tasks will be collected in `self.runnables`.
183    fn wire(&mut self) -> Result<(), ServiceError> {
184        // Initialize tasks.
185        let wiring_layers = std::mem::take(&mut self.layers);
186
187        let mut errors: Vec<(String, WiringError)> = Vec::new();
188
189        let runtime_handle = self.runtime.handle().clone();
190        for (name, WireFn(wire_fn)) in wiring_layers {
191            // We must process wiring layers sequentially and in the same order as they were added.
192            let mut context = ServiceContext::new(name, self);
193            let task_result = wire_fn(&runtime_handle, &mut context);
194            if let Err(err) = task_result {
195                // We don't want to bail on the first error, since it'll provide worse DevEx:
196                // People likely want to fix as much problems as they can in one go, rather than have
197                // to fix them one by one.
198                errors.push((name.to_string(), err));
199                continue;
200            };
201        }
202
203        // Report all the errors we've met during the init.
204        if !errors.is_empty() {
205            for (layer, error) in &errors {
206                tracing::error!("Wiring layer {layer} can't be initialized: {error:?}");
207            }
208            return Err(ServiceError::Wiring(errors));
209        }
210
211        if self.runnables.is_empty() {
212            return Err(ServiceError::NoTasks);
213        }
214
215        // Wiring is now complete.
216        for resource in self.resources.values_mut() {
217            resource.stored_resource_wired();
218        }
219        self.resources = HashMap::default(); // Decrement reference counters for resources.
220        tracing::info!("Wiring complete");
221
222        Ok(())
223    }
224
225    /// Prepares collected tasks for running.
226    fn prepare_tasks(&mut self) -> TaskReprs {
227        // Barrier that will only be lifted once all the preconditions are met.
228        // It will be awaited by the tasks before they start running and by the preconditions once they are fulfilled.
229        let task_barrier = self.runnables.task_barrier();
230
231        // Collect long-running tasks.
232        let stop_receiver = StopReceiver(self.stop_sender.subscribe());
233        self.runnables
234            .prepare_tasks(task_barrier.clone(), stop_receiver.clone())
235    }
236
237    /// Spawn the provided tasks and runs them until at least one task exits, and returns the list
238    /// of remaining tasks.
239    /// Adds error, if any, to the `errors` vector.
240    fn run_tasks(&mut self, tasks: Vec<NamedBoxFuture<eyre::Result<()>>>) -> Vec<TaskFuture> {
241        // Prepare tasks for running.
242        let rt_handle = self.runtime.handle().clone();
243        let join_handles: Vec<_> = tasks
244            .into_iter()
245            .map(|task| task.spawn(&rt_handle).fuse())
246            .collect();
247
248        // Collect names for remaining tasks for reporting purposes.
249        let mut tasks_names: Vec<_> = join_handles.iter().map(|task| task.id()).collect();
250
251        // Run the tasks until one of them exits.
252        let (resolved, resolved_idx, remaining) = self
253            .runtime
254            .block_on(futures::future::select_all(join_handles));
255        // Extract the result and report it to logs early, before waiting for any other task to shutdown.
256        // We will also collect the errors from the remaining tasks, hence a vector.
257        let task_name = tasks_names.swap_remove(resolved_idx);
258        self.handle_task_exit(resolved, task_name);
259        tracing::info!("One of the task has exited, shutting down the node");
260
261        remaining
262    }
263
264    /// Sends the stop signal and waits for the remaining tasks to finish.
265    fn shutdown_tasks(&mut self, remaining: Vec<TaskFuture>) {
266        // Send stop signal to remaining tasks and wait for them to finish.
267        self.stop_sender.send(true).ok();
268
269        // Collect names for remaining tasks for reporting purposes.
270        // We have to re-collect, becuase `select_all` does not guarantes the order of returned remaining futures.
271        let remaining_tasks_names: Vec<_> = remaining.iter().map(|task| task.id()).collect();
272        let remaining_tasks_with_timeout: Vec<_> = remaining
273            .into_iter()
274            .map(|task| async { tokio::time::timeout(TASK_SHUTDOWN_TIMEOUT, task).await })
275            .collect();
276
277        let execution_results = self
278            .runtime
279            .block_on(futures::future::join_all(remaining_tasks_with_timeout));
280
281        // Report the results of the remaining tasks.
282        for (name, result) in remaining_tasks_names.into_iter().zip(execution_results) {
283            match result {
284                Ok(resolved) => {
285                    self.handle_task_exit(resolved, name);
286                }
287                Err(_) => {
288                    tracing::error!("Task {name} timed out");
289                    self.errors.push(TaskError::TaskShutdownTimedOut(name));
290                }
291            }
292        }
293    }
294
295    /// Runs the provided shutdown hooks.
296    fn run_shutdown_hooks(&mut self, shutdown_hooks: Vec<NamedBoxFuture<eyre::Result<()>>>) {
297        // Run shutdown hooks sequentially.
298        for hook in shutdown_hooks {
299            let name = hook.id().clone();
300            // Limit each shutdown hook to the same timeout as the tasks.
301            let hook_with_timeout =
302                async move { tokio::time::timeout(TASK_SHUTDOWN_TIMEOUT, hook).await };
303            match self.runtime.block_on(hook_with_timeout) {
304                Ok(Ok(())) => {
305                    tracing::info!("Shutdown hook {name} completed");
306                }
307                Ok(Err(err)) => {
308                    tracing::error!("Shutdown hook {name} failed: {err:?}");
309                    self.errors.push(TaskError::ShutdownHookFailed(name, err));
310                }
311                Err(_) => {
312                    tracing::error!("Shutdown hook {name} timed out");
313                    self.errors.push(TaskError::ShutdownHookTimedOut(name));
314                }
315            }
316        }
317    }
318
319    /// Checks the result of the task execution, logs the result, and stores the error if any.
320    fn handle_task_exit(
321        &mut self,
322        task_result: Result<eyre::Result<()>, tokio::task::JoinError>,
323        task_name: TaskId,
324    ) {
325        match task_result {
326            Ok(Ok(())) => {
327                tracing::info!("Task {task_name} finished");
328            }
329            Ok(Err(err)) => {
330                tracing::error!("Task {task_name} failed: {err:?}");
331                self.errors.push(TaskError::TaskFailed(task_name, err));
332            }
333            Err(panic_err) => {
334                let panic_msg = try_extract_panic_message(panic_err);
335                tracing::error!("Task {task_name} panicked: {panic_msg}");
336                self.errors
337                    .push(TaskError::TaskPanicked(task_name, panic_msg));
338            }
339        };
340    }
341}