wire_framework/service/
mod.rs1use std::{collections::HashMap, time::Duration};
2
3use crate::utils::try_extract_panic_message;
4use futures::future::Fuse;
5use tokio::{runtime::Runtime, sync::watch, task::JoinHandle};
6
7pub use self::{
8 context::ServiceContext,
9 context_traits::{FromContext, IntoContext},
10 error::{ServiceError, TaskError},
11 shutdown_hook::ShutdownHook,
12 stop_receiver::StopReceiver,
13};
14use crate::{
15 resource::{ResourceId, StoredResource},
16 service::{
17 named_future::NamedFuture,
18 runnables::{NamedBoxFuture, Runnables, TaskReprs},
19 },
20 task::TaskId,
21 wiring_layer::{WireFn, WiringError, WiringLayer, WiringLayerExt},
22};
23
24mod context;
25mod context_traits;
26mod error;
27mod named_future;
28mod runnables;
29mod shutdown_hook;
30mod stop_receiver;
31#[cfg(test)]
32mod tests;
33
/// Upper bound on how long each remaining task, and each individual shutdown hook, may take to
/// finish once the service has started shutting down.
const TASK_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(30 * 1_000);
36
37#[derive(Debug)]
39pub struct ServiceBuilder {
40 layers: Vec<(&'static str, WireFn)>,
44 runtime: Runtime,
46}
47
48impl ServiceBuilder {
49 pub fn new() -> Result<Self, ServiceError> {
53 if tokio::runtime::Handle::try_current().is_ok() {
54 return Err(ServiceError::RuntimeDetected);
55 }
56 let runtime = tokio::runtime::Builder::new_multi_thread()
57 .enable_all()
58 .build()
59 .unwrap();
60 Ok(Self::on_runtime(runtime))
61 }
62
63 pub fn on_runtime(runtime: Runtime) -> Self {
69 Self {
70 layers: Vec::new(),
71 runtime,
72 }
73 }
74
75 pub fn runtime_handle(&self) -> tokio::runtime::Handle {
77 self.runtime.handle().clone()
78 }
79
80 pub fn add_layer<T: WiringLayer>(&mut self, layer: T) -> &mut Self {
90 let name = layer.layer_name();
91 if !self
92 .layers
93 .iter()
94 .any(|(existing_name, _)| name == *existing_name)
95 {
96 self.layers.push((name, layer.into_wire_fn()));
97 }
98 self
99 }
100
101 pub fn build(self) -> Service {
103 let (stop_sender, _stop_receiver) = watch::channel(false);
104
105 Service {
106 layers: self.layers,
107 resources: Default::default(),
108 runnables: Default::default(),
109 stop_sender,
110 runtime: self.runtime,
111 errors: Vec::new(),
112 }
113 }
114}
115
116#[derive(Debug)]
119pub struct Service {
120 resources: HashMap<ResourceId, Box<dyn StoredResource>>,
122 layers: Vec<(&'static str, WireFn)>,
124 runnables: Runnables,
126
127 stop_sender: watch::Sender<bool>,
129 runtime: Runtime,
131
132 errors: Vec<TaskError>,
134}
135
136type TaskFuture = NamedFuture<Fuse<JoinHandle<eyre::Result<()>>>>;
137
138impl Service {
139 pub fn run(self) -> Result<(), ServiceError> {
144 self.run_with_guard(())
145 }
146
147 pub fn run_with_guard<G>(mut self, observability_guard: G) -> Result<(), ServiceError> {
155 self.wire()?;
156
157 let TaskReprs {
158 tasks,
159 shutdown_hooks,
160 } = self.prepare_tasks();
161
162 let remaining = self.run_tasks(tasks);
163 self.shutdown_tasks(remaining);
164 self.run_shutdown_hooks(shutdown_hooks);
165
166 tracing::info!("Exiting the service");
167
168 if std::mem::needs_drop::<G>() {
169 let _guard = self.runtime.enter();
171 drop(observability_guard);
172 }
173
174 if self.errors.is_empty() {
175 Ok(())
176 } else {
177 Err(ServiceError::Task(self.errors.into()))
178 }
179 }
180
181 fn wire(&mut self) -> Result<(), ServiceError> {
184 let wiring_layers = std::mem::take(&mut self.layers);
186
187 let mut errors: Vec<(String, WiringError)> = Vec::new();
188
189 let runtime_handle = self.runtime.handle().clone();
190 for (name, WireFn(wire_fn)) in wiring_layers {
191 let mut context = ServiceContext::new(name, self);
193 let task_result = wire_fn(&runtime_handle, &mut context);
194 if let Err(err) = task_result {
195 errors.push((name.to_string(), err));
199 continue;
200 };
201 }
202
203 if !errors.is_empty() {
205 for (layer, error) in &errors {
206 tracing::error!("Wiring layer {layer} can't be initialized: {error:?}");
207 }
208 return Err(ServiceError::Wiring(errors));
209 }
210
211 if self.runnables.is_empty() {
212 return Err(ServiceError::NoTasks);
213 }
214
215 for resource in self.resources.values_mut() {
217 resource.stored_resource_wired();
218 }
219 self.resources = HashMap::default(); tracing::info!("Wiring complete");
221
222 Ok(())
223 }
224
225 fn prepare_tasks(&mut self) -> TaskReprs {
227 let task_barrier = self.runnables.task_barrier();
230
231 let stop_receiver = StopReceiver(self.stop_sender.subscribe());
233 self.runnables
234 .prepare_tasks(task_barrier.clone(), stop_receiver.clone())
235 }
236
237 fn run_tasks(&mut self, tasks: Vec<NamedBoxFuture<eyre::Result<()>>>) -> Vec<TaskFuture> {
241 let rt_handle = self.runtime.handle().clone();
243 let join_handles: Vec<_> = tasks
244 .into_iter()
245 .map(|task| task.spawn(&rt_handle).fuse())
246 .collect();
247
248 let mut tasks_names: Vec<_> = join_handles.iter().map(|task| task.id()).collect();
250
251 let (resolved, resolved_idx, remaining) = self
253 .runtime
254 .block_on(futures::future::select_all(join_handles));
255 let task_name = tasks_names.swap_remove(resolved_idx);
258 self.handle_task_exit(resolved, task_name);
259 tracing::info!("One of the task has exited, shutting down the node");
260
261 remaining
262 }
263
264 fn shutdown_tasks(&mut self, remaining: Vec<TaskFuture>) {
266 self.stop_sender.send(true).ok();
268
269 let remaining_tasks_names: Vec<_> = remaining.iter().map(|task| task.id()).collect();
272 let remaining_tasks_with_timeout: Vec<_> = remaining
273 .into_iter()
274 .map(|task| async { tokio::time::timeout(TASK_SHUTDOWN_TIMEOUT, task).await })
275 .collect();
276
277 let execution_results = self
278 .runtime
279 .block_on(futures::future::join_all(remaining_tasks_with_timeout));
280
281 for (name, result) in remaining_tasks_names.into_iter().zip(execution_results) {
283 match result {
284 Ok(resolved) => {
285 self.handle_task_exit(resolved, name);
286 }
287 Err(_) => {
288 tracing::error!("Task {name} timed out");
289 self.errors.push(TaskError::TaskShutdownTimedOut(name));
290 }
291 }
292 }
293 }
294
295 fn run_shutdown_hooks(&mut self, shutdown_hooks: Vec<NamedBoxFuture<eyre::Result<()>>>) {
297 for hook in shutdown_hooks {
299 let name = hook.id().clone();
300 let hook_with_timeout =
302 async move { tokio::time::timeout(TASK_SHUTDOWN_TIMEOUT, hook).await };
303 match self.runtime.block_on(hook_with_timeout) {
304 Ok(Ok(())) => {
305 tracing::info!("Shutdown hook {name} completed");
306 }
307 Ok(Err(err)) => {
308 tracing::error!("Shutdown hook {name} failed: {err:?}");
309 self.errors.push(TaskError::ShutdownHookFailed(name, err));
310 }
311 Err(_) => {
312 tracing::error!("Shutdown hook {name} timed out");
313 self.errors.push(TaskError::ShutdownHookTimedOut(name));
314 }
315 }
316 }
317 }
318
319 fn handle_task_exit(
321 &mut self,
322 task_result: Result<eyre::Result<()>, tokio::task::JoinError>,
323 task_name: TaskId,
324 ) {
325 match task_result {
326 Ok(Ok(())) => {
327 tracing::info!("Task {task_name} finished");
328 }
329 Ok(Err(err)) => {
330 tracing::error!("Task {task_name} failed: {err:?}");
331 self.errors.push(TaskError::TaskFailed(task_name, err));
332 }
333 Err(panic_err) => {
334 let panic_msg = try_extract_panic_message(panic_err);
335 tracing::error!("Task {task_name} panicked: {panic_msg}");
336 self.errors
337 .push(TaskError::TaskPanicked(task_name, panic_msg));
338 }
339 };
340 }
341}