watchexec_supervisor/job/
task.rs

1use std::{future::Future, mem::take, sync::Arc, time::Instant};
2
3use process_wrap::tokio::TokioCommandWrap;
4use tokio::{select, task::JoinHandle};
5use tracing::{instrument, trace, trace_span, Instrument};
6use watchexec_signals::Signal;
7
8use crate::{
9	command::Command,
10	errors::{sync_io_error, SyncIoError},
11	flag::Flag,
12	job::priority::Timer,
13};
14
15use super::{
16	job::Job,
17	messages::{Control, ControlMessage},
18	priority,
19	state::CommandState,
20};
21
22/// Spawn a job task and return a [`Job`] handle and a [`JoinHandle`].
23///
24/// The job task immediately starts in the background: it does not need polling.
25#[must_use]
26#[instrument(level = "trace")]
27pub fn start_job(command: Arc<Command>) -> (Job, JoinHandle<()>) {
28	enum Loop {
29		Normally,
30		Skip,
31		Break,
32	}
33
34	let (sender, mut receiver) = priority::new();
35	let gone = Flag::default();
36	let done = gone.clone();
37
38	(
39		Job {
40			command: command.clone(),
41			control_queue: sender,
42			gone,
43		},
44		tokio::spawn(async move {
45			let mut error_handler = ErrorHandler::None;
46			let mut spawn_hook = SpawnHook::None;
47			let mut command_state = CommandState::Pending;
48			let mut previous_run = None;
49			let mut stop_timer = None;
50			let mut on_end: Vec<Flag> = Vec::new();
51			let mut on_end_restart: Option<Flag> = None;
52
53			'main: loop {
54				select! {
55					result = command_state.wait(), if command_state.is_running() => {
56						trace!(?result, ?command_state, "got wait result");
57						match async {
58							#[cfg(test)] eprintln!("[{:?}] waited: {result:?}", Instant::now());
59
60							match result {
61								Err(err) => {
62									let fut = error_handler.call(sync_io_error(err));
63									fut.await;
64									return Loop::Skip;
65								}
66								Ok(true) => {
67									trace!(existing=?stop_timer, "erasing stop timer");
68									stop_timer = None;
69									trace!(count=%on_end.len(), "raising all pending end flags");
70									for done in take(&mut on_end) {
71										done.raise();
72									}
73
74									if let Some(flag) = on_end_restart.take() {
75										trace!("continuing a graceful restart");
76
77										let mut spawnable = command.to_spawnable();
78										previous_run = Some(command_state.reset());
79										spawn_hook
80											.call(
81												&mut spawnable,
82												&JobTaskContext {
83													command: command.clone(),
84													current: &command_state,
85													previous: previous_run.as_ref(),
86												},
87											)
88											.await;
89										if let Err(err) = command_state.spawn(command.clone(), spawnable) {
90											let fut = error_handler.call(sync_io_error(err));
91											fut.await;
92											return Loop::Skip;
93										}
94
95										trace!("raising graceful restart's flag");
96										flag.raise();
97									}
98								}
99								Ok(false) => {
100									trace!("child wasn't running, ignoring wait result");
101								}
102							}
103
104							Loop::Normally
105						}.instrument(trace_span!("handle wait result")).await {
106							Loop::Normally => {}
107							Loop::Skip => {
108								trace!("skipping to next event");
109								continue 'main;
110							}
111							Loop::Break => {
112								trace!("breaking out of main loop");
113								break 'main;
114							}
115						}
116					}
117					Some(ControlMessage { control, done }) = receiver.recv(&mut stop_timer) => {
118						match async {
119							trace!(?control, ?command_state, "got control message");
120							#[cfg(test)] eprintln!("[{:?}] control: {control:?}", Instant::now());
121
122							macro_rules! try_with_handler {
123								($erroring:expr) => {
124									match $erroring {
125										Err(err) => {
126											let fut = error_handler.call(sync_io_error(err));
127											fut.await;
128											trace!("raising done flag for this control after error");
129											done.raise();
130											return Loop::Normally;
131										}
132										Ok(value) => value,
133									}
134								};
135							}
136
137							match control {
138								Control::Start => {
139									if command_state.is_running() {
140										trace!("child is running, skip");
141									} else {
142										let mut spawnable = command.to_spawnable();
143										previous_run = Some(command_state.reset());
144										spawn_hook
145											.call(
146												&mut spawnable,
147												&JobTaskContext {
148													command: command.clone(),
149													current: &command_state,
150													previous: previous_run.as_ref(),
151												},
152											)
153											.await;
154										try_with_handler!(command_state.spawn(command.clone(), spawnable));
155									}
156								}
157								Control::Stop => {
158									if let CommandState::Running { child, started, .. } = &mut command_state {
159										trace!("stopping child");
160										try_with_handler!(Box::into_pin(child.kill()).await);
161										trace!("waiting on child");
162										let status = try_with_handler!(Box::into_pin(child.wait()).await);
163
164										trace!(?status, "got child end status");
165										command_state = CommandState::Finished {
166											status: status.into(),
167											started: *started,
168											finished: Instant::now(),
169										};
170
171										trace!(count=%on_end.len(), "raising all pending end flags");
172										for done in take(&mut on_end) {
173											done.raise();
174										}
175									} else {
176										trace!("child isn't running, skip");
177									}
178								}
179								Control::GracefulStop { signal, grace } => {
180									if let CommandState::Running { child, .. } = &mut command_state {
181										try_with_handler!(signal_child(signal, child).await);
182
183										trace!(?grace, "setting up graceful stop timer");
184										stop_timer.replace(Timer::stop(grace, done));
185										return Loop::Skip;
186									}
187									trace!("child isn't running, skip");
188								}
189								Control::TryRestart => {
190									if let CommandState::Running { child, started, .. } = &mut command_state {
191										trace!("stopping child");
192										try_with_handler!(Box::into_pin(child.kill()).await);
193										trace!("waiting on child");
194										let status = try_with_handler!(Box::into_pin(child.wait()).await);
195
196										trace!(?status, "got child end status");
197										command_state = CommandState::Finished {
198											status: status.into(),
199											started: *started,
200											finished: Instant::now(),
201										};
202										previous_run = Some(command_state.reset());
203
204										trace!(count=%on_end.len(), "raising all pending end flags");
205										for done in take(&mut on_end) {
206											done.raise();
207										}
208
209										let mut spawnable = command.to_spawnable();
210										spawn_hook
211											.call(
212												&mut spawnable,
213												&JobTaskContext {
214													command: command.clone(),
215													current: &command_state,
216													previous: previous_run.as_ref(),
217												},
218											)
219											.await;
220										try_with_handler!(command_state.spawn(command.clone(), spawnable));
221									} else {
222										trace!("child isn't running, skip");
223									}
224								}
225								Control::TryGracefulRestart { signal, grace } => {
226									if let CommandState::Running { child, .. } = &mut command_state {
227										try_with_handler!(signal_child(signal, child).await);
228
229										trace!(?grace, "setting up graceful stop timer");
230										stop_timer.replace(Timer::restart(grace, done.clone()));
231										trace!("setting up graceful restart flag");
232										on_end_restart = Some(done);
233										return Loop::Skip;
234									}
235									trace!("child isn't running, skip");
236								}
237								Control::ContinueTryGracefulRestart => {
238									trace!("continuing a graceful try-restart");
239
240									if let CommandState::Running { child, started, .. } = &mut command_state {
241										trace!("stopping child forcefully");
242										try_with_handler!(Box::into_pin(child.kill()).await);
243										trace!("waiting on child");
244										let status = try_with_handler!(Box::into_pin(child.wait()).await);
245
246										trace!(?status, "got child end status");
247										command_state = CommandState::Finished {
248											status: status.into(),
249											started: *started,
250											finished: Instant::now(),
251										};
252
253										trace!(count=%on_end.len(), "raising all pending end flags");
254										for done in take(&mut on_end) {
255											done.raise();
256										}
257									}
258
259									let mut spawnable = command.to_spawnable();
260									previous_run = Some(command_state.reset());
261									spawn_hook
262										.call(
263											&mut spawnable,
264											&JobTaskContext {
265												command: command.clone(),
266												current: &command_state,
267												previous: previous_run.as_ref(),
268											},
269										)
270										.await;
271									try_with_handler!(command_state.spawn(command.clone(), spawnable));
272								}
273								Control::Signal(signal) => {
274									if let CommandState::Running { child, .. } = &mut command_state {
275										try_with_handler!(signal_child(signal, child).await);
276									} else {
277										trace!("child isn't running, skip");
278									}
279								}
280								Control::Delete => {
281									trace!("raising done flag immediately");
282									done.raise();
283									return Loop::Break;
284								}
285
286								Control::NextEnding => {
287									if matches!(command_state, CommandState::Finished { .. }) {
288										trace!("child is finished, raise done flag immediately");
289										done.raise();
290										return Loop::Normally;
291									}
292										trace!("queue end flag");
293										on_end.push(done);
294										return Loop::Skip;
295								}
296
297								Control::SyncFunc(f) => {
298									f(&JobTaskContext {
299										command: command.clone(),
300										current: &command_state,
301										previous: previous_run.as_ref(),
302									});
303								}
304								Control::AsyncFunc(f) => {
305									Box::into_pin(f(&JobTaskContext {
306										command: command.clone(),
307										current: &command_state,
308										previous: previous_run.as_ref(),
309									}))
310									.await;
311								}
312
313								Control::SetSyncErrorHandler(f) => {
314									trace!("setting sync error handler");
315									error_handler = ErrorHandler::Sync(f);
316								}
317								Control::SetAsyncErrorHandler(f) => {
318									trace!("setting async error handler");
319									error_handler = ErrorHandler::Async(f);
320								}
321								Control::UnsetErrorHandler => {
322									trace!("unsetting error handler");
323									error_handler = ErrorHandler::None;
324								}
325								Control::SetSyncSpawnHook(f) => {
326									trace!("setting sync spawn hook");
327									spawn_hook = SpawnHook::Sync(f);
328								}
329								Control::SetAsyncSpawnHook(f) => {
330									trace!("setting async spawn hook");
331									spawn_hook = SpawnHook::Async(f);
332								}
333								Control::UnsetSpawnHook => {
334									trace!("unsetting spawn hook");
335									spawn_hook = SpawnHook::None;
336								}
337							}
338
339							trace!("raising control done flag");
340							done.raise();
341
342							Loop::Normally
343					}.instrument(trace_span!("handle control message")).await {
344						Loop::Normally => {}
345						Loop::Skip => {
346							trace!("skipping to next event (without raising done flag)");
347							continue 'main;
348						}
349						Loop::Break => {
350							trace!("breaking out of main loop");
351							break 'main;
352						}
353					}
354				}
355				}
356			}
357
358			trace!("raising job done flag");
359			done.raise();
360		}),
361	)
362}
363
/// Generates a public enum `$name` that holds no callback, a synchronous callback of type
/// `$synct`, or an asynchronous callback of type `$asynct`.
///
/// The enum gets an async `call` method taking the listed arguments: it does nothing for
/// `None`, invokes the sync callback directly, or invokes the async callback and awaits the
/// boxed future it returns.
macro_rules! sync_async_callbox {
	($name:ident, $synct:ty, $asynct:ty, ($($argname:ident : $argtype:ty),*)) => {
		pub enum $name {
			None,
			Sync($synct),
			Async($asynct),
		}

		impl $name {
			#[instrument(level = "trace", skip(self, $($argname),*))]
			pub async fn call(&self, $($argname: $argtype),*) {
				match self {
					Self::None => {}
					Self::Sync(callback) => {
						::tracing::trace!("calling sync {:?}", stringify!($name));
						callback($($argname),*);
					}
					Self::Async(callback) => {
						::tracing::trace!("calling async {:?}", stringify!($name));
						// The callback hands back a boxed future; pin it so it can be awaited.
						let pending = Box::into_pin(callback($($argname),*));
						pending.await;
					}
				}
			}
		}
	};
}
390
391/// Job task internals exposed via hooks.
392#[derive(Debug)]
393pub struct JobTaskContext<'task> {
394	/// The job's [`Command`].
395	pub command: Arc<Command>,
396
397	/// The current state of the job.
398	pub current: &'task CommandState,
399
400	/// The state of the previous iteration of the job, if any.
401	///
402	/// This is generally [`CommandState::Finished`], but may be other states in rare cases.
403	pub previous: Option<&'task CommandState>,
404}
405
406pub type SyncFunc = Box<dyn FnOnce(&JobTaskContext<'_>) + Send + Sync + 'static>;
407pub type AsyncFunc = Box<
408	dyn (FnOnce(&JobTaskContext<'_>) -> Box<dyn Future<Output = ()> + Send + Sync>)
409		+ Send
410		+ Sync
411		+ 'static,
412>;
413
414pub type SyncSpawnHook =
415	Arc<dyn Fn(&mut TokioCommandWrap, &JobTaskContext<'_>) + Send + Sync + 'static>;
416pub type AsyncSpawnHook = Arc<
417	dyn (Fn(&mut TokioCommandWrap, &JobTaskContext<'_>) -> Box<dyn Future<Output = ()> + Send + Sync>)
418		+ Send
419		+ Sync
420		+ 'static,
421>;
422
423sync_async_callbox!(SpawnHook, SyncSpawnHook, AsyncSpawnHook, (command: &mut TokioCommandWrap, context: &JobTaskContext<'_>));
424
425pub type SyncErrorHandler = Arc<dyn Fn(SyncIoError) + Send + Sync + 'static>;
426pub type AsyncErrorHandler = Arc<
427	dyn (Fn(SyncIoError) -> Box<dyn Future<Output = ()> + Send + Sync>) + Send + Sync + 'static,
428>;
429
430sync_async_callbox!(ErrorHandler, SyncErrorHandler, AsyncErrorHandler, (error: SyncIoError));
431
432#[cfg_attr(not(windows), allow(clippy::needless_pass_by_ref_mut))] // needed for start_kill()
433#[instrument(level = "trace")]
434async fn signal_child(
435	signal: Signal,
436	#[cfg(not(test))] child: &mut Box<dyn process_wrap::tokio::TokioChildWrapper>,
437	#[cfg(test)] child: &mut super::TestChild,
438) -> std::io::Result<()> {
439	#[cfg(unix)]
440	{
441		let sig = signal
442			.to_nix()
443			.or_else(|| Signal::Terminate.to_nix())
444			.expect("UNWRAP: guaranteed for Signal::Terminate default");
445		trace!(signal=?sig, "sending signal");
446		child.signal(sig as _)?;
447	}
448
449	#[cfg(windows)]
450	if signal == Signal::ForceStop {
451		trace!("starting kill, without waiting");
452		child.start_kill()?;
453	} else {
454		trace!(?signal, "ignoring unsupported signal");
455	}
456
457	Ok(())
458}