Skip to main content

zng_task/process/
worker.rs

1//! Async worker process tasks.
2//!
3//! This module defines a worker process that can run tasks in a separate process instance.
4//!
5//! Each worker process can run multiple tasks in parallel, the worker type is [`Worker`]. Note that this module does not offer a fork
6//! implementation, the worker processes begin from the start state. The primary use of process tasks is to make otherwise fatal tasks
7//! recoverable, if the task calls unsafe code or code that can potentially terminate the entire process it should run using a [`Worker`].
8//! If you only want to recover from panics in safe code consider using [`task::run_catch`] or [`task::wait_catch`] instead.
9//!
10//! You can send [IPC channel] endpoints in the task request messages, this can be useful for implementing progress reporting,
11//! you can also send [`IpcBytes`] to efficiently share large byte blobs with the worker process.
12//!
13//! [`task::run_catch`]: crate::run_catch
14//! [`task::wait_catch`]: crate::wait_catch
15//! [IPC channel]: crate::channel::ipc_unbounded
16//! [`IpcBytes`]: crate::channel::IpcBytes
17//!
18//! # Examples
19//!
20//! The example below demonstrates a worker-process setup that uses the same executable as the app-process.
21//!
22//! ```
23//! # fn main() { }
24//! # mod zng { pub mod env { pub use zng_env::*; } pub mod task { pub use zng_task::*; } }
25//! # fn demo() {
26//! fn main() {
27//!     zng::env::init!();
28//!     // normal app init..
29//!     # zng::task::doc_test(false, on_click());
30//! }
31//! # }
32//!
33//! mod task1 {
34//! # use super::zng;
35//!     use zng::{task::process::worker, env};
36//!
37//!     const NAME: &str = "zng::example::task1";
38//!
39//!     env::on_process_start!(|args| {
40//!         // give tracing handlers a chance to observe the worker-process
41//!         if args.yield_count == 0 { return args.yield_once(); }
//!         // run the worker server
43//!         worker::run_worker(NAME, work);
44//!     });
45//!     async fn work(args: worker::RequestArgs<Request>) -> Response {
46//!         let rsp = format!("received 'task1' request `{:?}` in worker-process #{}", &args.request.data, std::process::id());
47//!         Response { data: rsp }
48//!     }
49//!     
50//!     #[derive(Debug, serde::Serialize, serde::Deserialize)]
51//!     pub struct Request { pub data: String }
52//!
53//!     #[derive(Debug, serde::Serialize, serde::Deserialize)]
54//!     pub struct Response { pub data: String }
55//!
56//!     // called in app-process
57//!     pub async fn start() -> worker::Worker<Request, Response> {
58//!         worker::Worker::start(NAME).await.expect("cannot spawn 'task1'")
59//!     }
60//! }
61//!
62//! // This runs in the app-process, it starts a worker process and requests a task run.
63//! async fn on_click() {
64//!     println!("app-process #{} starting a worker", std::process::id());
65//!     let mut worker = task1::start().await;
66//!     // request a task run and await it.
67//!     match worker.run(task1::Request { data: "request".to_owned() }).await {
68//!         Ok(task1::Response { data }) => println!("ok. {data}"),
69//!         Err(e) => eprintln!("error: {e}"),
70//!     }
71//!     // multiple tasks can be requested in parallel, use `task::all!` to await ..
72//!
73//!     // the worker process can be gracefully shutdown, awaits all pending tasks.
74//!     let _ = worker.shutdown().await;
75//! }
76//! ```
77//!
//! Note that you can setup multiple workers in the same executable, as long as the `on_process_start!` call happens
//! on different modules.
80//!
81//! # Connect Timeout
82//!
//! If the worker process takes longer than 10 seconds to connect, starting the worker fails with a timeout error. This is more
//! than enough in most cases, but it can be too little in some test runner machines. You can set the `"ZNG_TASK_WORKER_TIMEOUT"`
//! environment variable to a custom timeout in seconds. The minimum value is 1 second, set to 0 or empty to use the default timeout.
86
87use core::fmt;
88use std::{marker::PhantomData, path::PathBuf, pin::Pin, process::Stdio, sync::Arc};
89
90use parking_lot::Mutex;
91use zng_clone_move::{async_clmv, clmv};
92use zng_txt::Txt;
93use zng_unique_id::IdMap;
94use zng_unit::TimeUnits as _;
95
96use crate::{
97    TaskPanicError,
98    channel::{self, ChannelError, IpcReceiver, IpcSender, IpcValue, NamedIpcSender},
99    process::tap::{StderrTap, contains_ansi_csi, remove_ansi_csi},
100};
101
102use super::tap::PanicInfo;
103
// Env var set on the spawned worker-process with the app-process `VERSION`.
const WORKER_VERSION: &str = "ZNG_TASK_IPC_WORKER_VERSION";
// Env var set on the spawned worker-process with the name of the init IPC channel it must connect to.
const WORKER_SERVER: &str = "ZNG_TASK_IPC_WORKER_SERVER";
// Env var set on the spawned worker-process with the name of the worker that must run.
const WORKER_NAME: &str = "ZNG_TASK_IPC_WORKER_NAME";

// Env var read by the app-process to override the connect timeout, in seconds.
const WORKER_TIMEOUT: &str = "ZNG_TASK_WORKER_TIMEOUT";

/// The *App Process* and *Worker Process* must be built using the exact same version, this is
/// validated at run-time by the worker process, it prints an error to stderr and exits if the versions don't match.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
113
/// Represents a running worker process.
///
/// `I` is the request (input) type and `O` is the response (output) type of the tasks.
pub struct Worker<I: IpcValue, O: IpcValue> {
    // Response receiver thread, worker process and stderr recorder.
    // Is `None` after `shutdown`, after `crash_error` collected a crash or after drop.
    running: Option<(std::thread::JoinHandle<()>, std::process::Child, StderrTap)>,

    // Sends requests to the worker-process.
    sender: IpcSender<(RequestId, Request<I>)>,
    // Pending requests, shared with the receiver thread that removes each entry when the response arrives.
    requests: Arc<Mutex<IdMap<RequestId, channel::Sender<O>>>>,

    _p: PhantomData<fn(I) -> O>,

    // Crash info collected by `crash_error`.
    crash: Option<WorkerCrashError>,
}
125impl<I: IpcValue, O: IpcValue> Worker<I, O> {
126    /// Start a worker process implemented in the current executable.
127    ///
128    /// Note that the current process must call [`run_worker`] at startup to actually work.
129    /// You can use [`zng_env::on_process_start!`] to inject startup code.
130    pub async fn start(worker_name: impl Into<Txt>) -> std::io::Result<Self> {
131        Self::start_impl(worker_name.into(), std::env::current_exe()?, &[], &[]).await
132    }
133
134    /// Start a worker process implemented in the current executable with custom env vars and args.
135    pub async fn start_with(worker_name: impl Into<Txt>, env_vars: &[(&str, &str)], args: &[&str]) -> std::io::Result<Self> {
136        Self::start_impl(worker_name.into(), std::env::current_exe()?, env_vars, args).await
137    }
138
139    /// Start a worker process implemented in another executable with custom env vars and args.
140    pub async fn start_other(
141        worker_name: impl Into<Txt>,
142        worker_exe: impl Into<PathBuf>,
143        env_vars: &[(&str, &str)],
144        args: &[&str],
145    ) -> std::io::Result<Self> {
146        Self::start_impl(worker_name.into(), worker_exe.into(), env_vars, args).await
147    }
148
149    async fn start_impl(worker_name: Txt, exe: PathBuf, env_vars: &[(&str, &str)], args: &[&str]) -> std::io::Result<Self> {
150        let chan_sender = NamedIpcSender::<WorkerInit<I, O>>::new()?;
151
152        let mut worker = std::process::Command::new(dunce::canonicalize(exe)?);
153        for (key, value) in env_vars {
154            worker.env(key, value);
155        }
156        for arg in args {
157            worker.arg(arg);
158        }
159        worker
160            .env(WORKER_VERSION, crate::process::worker::VERSION)
161            .env(WORKER_SERVER, chan_sender.name())
162            .env(WORKER_NAME, worker_name)
163            .env("RUST_BACKTRACE", "full");
164
165        worker.stderr(Stdio::piped());
166
167        let mut worker = blocking::unblock(move || worker.spawn()).await?;
168
169        let timeout = match std::env::var(WORKER_TIMEOUT) {
170            Ok(t) if !t.is_empty() => match t.parse::<u64>() {
171                Ok(t) => t.max(1),
172                Err(e) => {
173                    tracing::error!("invalid {WORKER_TIMEOUT:?} value, {e}");
174                    10
175                }
176            },
177            _ => 10,
178        };
179
180        let (request_sender, mut response_receiver) = match Self::connect_worker(chan_sender, timeout).await {
181            Ok(r) => r,
182            Err(ce) => {
183                let cleanup = blocking::unblock(move || {
184                    worker.kill()?;
185                    worker.wait()
186                });
187                match cleanup.await {
188                    Ok(status) => {
189                        let code = status.code().unwrap_or(0);
190                        return Err(std::io::Error::new(
191                            std::io::ErrorKind::TimedOut,
192                            format!("worker process did not connect in {timeout}s\nworker exit code: {code}\nchannel error: {ce}"),
193                        ));
194                    }
195                    Err(e) => {
196                        return Err(std::io::Error::new(
197                            std::io::ErrorKind::TimedOut,
198                            format!("worker process did not connect in {timeout}s\ncannot kill worker process, {e}\nchannel error: {ce}"),
199                        ));
200                    }
201                }
202            }
203        };
204
205        let requests = Arc::new(Mutex::new(IdMap::<RequestId, channel::Sender<O>>::new()));
206        let receiver = std::thread::Builder::new()
207            .name("task-ipc-recv".into())
208            .stack_size(256 * 1024)
209            .spawn(clmv!(requests, || {
210                loop {
211                    match response_receiver.recv_blocking() {
212                        Ok((id, r)) => match requests.lock().remove(&id) {
213                            Some(s) => match r {
214                                Response::Out(r) => {
215                                    let _ = s.send_blocking(r);
216                                }
217                            },
218                            None => tracing::error!("worker responded to unknown request #{}", id.sequential()),
219                        },
220                        Err(e) => match e {
221                            ChannelError::Disconnected { .. } => {
222                                requests.lock().clear();
223                                break;
224                            }
225                            e => {
226                                tracing::error!("worker response error, will shutdown, {e}");
227                                break;
228                            }
229                        },
230                    }
231                }
232            }))
233            .expect("failed to spawn thread");
234
235        let stderr_tap = StderrTap::new_blocking(worker.stderr.take().unwrap());
236
237        Ok(Self {
238            running: Some((receiver, worker, stderr_tap)),
239            sender: request_sender,
240            _p: PhantomData,
241            crash: None,
242            requests,
243        })
244    }
245    async fn connect_worker(
246        chan_sender: NamedIpcSender<WorkerInit<I, O>>,
247        timeout: u64,
248    ) -> Result<(IpcSender<(RequestId, Request<I>)>, IpcReceiver<(RequestId, Response<O>)>), ChannelError> {
249        let mut chan_sender = chan_sender.connect_deadline(timeout.secs()).await?;
250
251        let (request_sender, request_receiver) =
252            channel::ipc_unbounded::<(RequestId, Request<I>)>().map_err(ChannelError::disconnected_by)?;
253        let (response_sender, response_receiver) =
254            channel::ipc_unbounded::<(RequestId, Response<O>)>().map_err(ChannelError::disconnected_by)?;
255
256        chan_sender.send_blocking((request_receiver, response_sender))?;
257
258        Ok((request_sender, response_receiver))
259    }
260
261    /// Awaits current tasks and kills the worker process.
262    pub async fn shutdown(mut self) -> std::io::Result<()> {
263        if let Some((receiver, mut worker, _)) = self.running.take() {
264            while !self.requests.lock().is_empty() {
265                crate::deadline(100.ms()).await;
266            }
267            let r = blocking::unblock(move || {
268                worker.kill()?;
269                worker.wait()?;
270                Ok(())
271            })
272            .await;
273
274            match crate::with_deadline(blocking::unblock(move || receiver.join()), 1.secs()).await {
275                Ok(r) => {
276                    if let Err(p) = r {
277                        tracing::error!(
278                            "worker receiver thread exited panicked, {}",
279                            TaskPanicError::new(p).panic_str().unwrap_or("")
280                        );
281                    }
282                }
283                Err(_) => {
284                    // timeout
285                    if r.is_ok() {
286                        // after awaiting kill receiver thread should join fast because disconnect breaks loop
287                        panic!("worker receiver thread did not exit after worker process did");
288                    }
289                }
290            }
291            r
292        } else {
293            Ok(())
294        }
295    }
296
297    /// Run a task in a free worker process thread.
298    pub fn run(&mut self, input: I) -> impl Future<Output = Result<O, RunError>> + Send + 'static {
299        self.run_request(Request::Run(input))
300    }
301
302    fn run_request(&mut self, request: Request<I>) -> Pin<Box<dyn Future<Output = Result<O, RunError>> + Send + 'static>> {
303        if self.crash_error().is_some() {
304            return Box::pin(std::future::ready(Err(RunError::Disconnected)));
305        }
306
307        let id = RequestId::new_unique();
308        let (sx, rx) = channel::bounded(1);
309
310        let requests = self.requests.clone();
311        requests.lock().insert(id, sx);
312        let mut sender = self.sender.clone();
313        let send_r = blocking::unblock(move || sender.send_blocking((id, request)));
314
315        Box::pin(async move {
316            if let Err(e) = send_r.await {
317                tracing::error!("cannot send request, {e}");
318                requests.lock().remove(&id);
319                return Err(RunError::Other(Arc::new(e)));
320            }
321
322            match rx.recv().await {
323                Ok(r) => Ok(r),
324                Err(e) => match e {
325                    ChannelError::Disconnected { cause } => {
326                        let cause = match cause {
327                            Some(e) => format!(", {e}"),
328                            None => String::new(),
329                        };
330                        tracing::error!("cannot receive response, disconnected{cause}, more info in `crash_error`");
331                        requests.lock().remove(&id);
332                        Err(RunError::Disconnected)
333                    }
334                    _ => unreachable!(),
335                },
336            }
337        })
338    }
339
340    /// Reference the crash error.
341    ///
342    /// The worker cannot be used if this is set, run requests will immediately disconnect.
343    pub fn crash_error(&mut self) -> Option<&WorkerCrashError> {
344        // TODO(breaking) make this async
345        if let Some((t, _, _)) = &self.running
346            && t.is_finished()
347        {
348            let (t, mut p, stderr) = self.running.take().unwrap();
349
350            if let Err(e) = t.join() {
351                tracing::error!(
352                    "panic in worker receiver thread, {}",
353                    TaskPanicError::new(e).panic_str().unwrap_or("")
354                );
355            }
356
357            if let Err(e) = p.kill() {
358                tracing::error!("error killing worker process after receiver exit, {e}");
359            }
360
361            match p.wait() {
362                Ok(o) => {
363                    self.crash = Some(WorkerCrashError {
364                        status: o,
365                        stderr: stderr.into_txt_blocking(false),
366                    });
367                }
368                Err(e) => tracing::error!("error reading crashed worker output, {e}"),
369            }
370        }
371
372        self.crash.as_ref()
373    }
374}
375impl<I: IpcValue, O: IpcValue> Drop for Worker<I, O> {
376    fn drop(&mut self) {
377        if let Some((receiver, mut process, _)) = self.running.take() {
378            if !receiver.is_finished() {
379                tracing::error!("dropped worker without shutdown");
380            }
381            if let Err(e) = process.kill() {
382                tracing::error!("failed to kill worker process on drop, {e}");
383            }
384        }
385    }
386}
387
388/// If the process was started by a [`Worker`] runs the worker loop and never returns. If
389/// not started as worker does nothing.
390///
391/// The `handler` is called for each work request.
392pub fn run_worker<I, O, F>(worker_name: impl Into<Txt>, handler: impl Fn(RequestArgs<I>) -> F + Send + Sync + 'static)
393where
394    I: IpcValue,
395    O: IpcValue,
396    F: Future<Output = O> + Send + 'static,
397{
398    let name = worker_name.into();
399    if let Some(server_name) = run_worker_server(&name) {
400        zng_env::init_process_name(zng_txt::formatx!("worker-process ({name}, {})", std::process::id()));
401
402        let mut chan_recv = IpcReceiver::<WorkerInit<I, O>>::connect(server_name)
403            .unwrap_or_else(|e| panic!("failed to connect to '{name}' init channel, {e}"));
404
405        let (mut request_receiver, response_sender) = chan_recv
406            .recv_blocking()
407            .unwrap_or_else(|e| panic!("failed to connect initial channels, {e}"));
408
409        let handler = Arc::new(handler);
410
411        loop {
412            match request_receiver.recv_blocking() {
413                Ok((id, input)) => match input {
414                    Request::Run(r) => crate::spawn(async_clmv!(handler, mut response_sender, {
415                        let output = handler(RequestArgs { request: r }).await;
416                        let _ = response_sender.send_blocking((id, Response::Out(output)));
417                    })),
418                },
419                Err(e) => match e {
420                    ChannelError::Disconnected { cause } => {
421                        match cause {
422                            Some(e) => tracing::error!("exit worker, disconnected, {e}"),
423                            None => tracing::debug!("exit worker, disconnected"),
424                        }
425                        break;
426                    }
427                    ChannelError::Timeout => unreachable!(),
428                },
429            }
430        }
431
432        zng_env::exit(0);
433    }
434}
435fn run_worker_server(worker_name: &str) -> Option<String> {
436    if let Ok(w_name) = std::env::var(WORKER_NAME)
437        && let Ok(version) = std::env::var(WORKER_VERSION)
438        && let Ok(server_name) = std::env::var(WORKER_SERVER)
439    {
440        if w_name != worker_name {
441            return None;
442        }
443        if version != VERSION {
444            eprintln!("worker '{worker_name}' API version is not equal, app-process: {version}, worker-process: {VERSION}");
445            zng_env::exit(i32::from_le_bytes(*b"vapi"));
446        }
447
448        Some(server_name)
449    } else {
450        None
451    }
452}
453
/// Arguments for the `handler` of [`run_worker`].
///
/// One instance is created for each run request received by the worker process.
#[non_exhaustive]
pub struct RequestArgs<I: IpcValue> {
    /// The task request data.
    pub request: I,
}
460
/// Error of a task run request.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum RunError {
    /// The connection with the worker process was lost.
    ///
    /// See [`Worker::crash_error`] for the error.
    Disconnected,
    /// Any other error, for example a failure to send the request.
    Other(Arc<dyn std::error::Error + Send + Sync>),
}
impl fmt::Display for RunError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let RunError::Other(e) = self {
            return write!(f, "run error, {e}");
        }
        f.write_str("worker process disconnected")
    }
}
impl std::error::Error for RunError {}
481
/// Info about a worker process crash.
///
/// See [`Worker::crash_error`] for how this is collected.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct WorkerCrashError {
    /// Worker process exit status.
    pub status: std::process::ExitStatus,
    /// Recorded stderr of the worker process.
    pub stderr: Txt,
}
491impl WorkerCrashError {
492    /// Gets if `stderr` does not contain any ANSI scape sequences.
493    pub fn is_stderr_plain(&self) -> bool {
494        !contains_ansi_csi(&self.stderr)
495    }
496
497    /// Get `stderr` without any ANSI escape sequences (CSI).
498    pub fn stderr_plain(&self) -> Txt {
499        if self.is_stderr_plain() {
500            self.stderr.clone()
501        } else {
502            remove_ansi_csi(&self.stderr)
503        }
504    }
505
506    /// Try parse `stderr` for the crash panic if exit code was `101`.
507    ///
508    /// Only reliably works if the panic fully printed correctly and was formatted by the panic
509    /// hook installed by `crash_handler` or by the display print of [`PanicInfo`].
510    pub fn find_panic(&self) -> Option<PanicInfo> {
511        if self.status.code() == Some(101) {
512            PanicInfo::find(&self.stderr)
513        } else {
514            None
515        }
516    }
517}
518impl fmt::Display for WorkerCrashError {
519    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
520        write!(f, "{:?}", self.status)
521    }
522}
523impl std::error::Error for WorkerCrashError {}
524
// Message sent from the app-process to the worker-process.
#[derive(serde::Serialize, serde::Deserialize)]
enum Request<I> {
    // Run the handler with the input.
    Run(I),
}
529
// Message sent from the worker-process to the app-process.
#[derive(serde::Serialize, serde::Deserialize)]
enum Response<O> {
    // Output of the handler for a `Request::Run`.
    Out(O),
}
534
// Endpoints sent by the app-process over the named init channel, the worker-process
// receives requests from the first and sends responses in the second.
type WorkerInit<I, O> = (
    channel::IpcReceiver<(RequestId, Request<I>)>,
    channel::IpcSender<(RequestId, Response<O>)>,
);
539
// Unique ID sent with each request and echoed in the response, used to find the pending request.
zng_unique_id::unique_id_64! {
    #[derive(serde::Serialize, serde::Deserialize)]
    struct RequestId;
}