1use core::fmt;
88use std::{marker::PhantomData, path::PathBuf, pin::Pin, process::Stdio, sync::Arc};
89
90use parking_lot::Mutex;
91use zng_clone_move::{async_clmv, clmv};
92use zng_txt::Txt;
93use zng_unique_id::IdMap;
94use zng_unit::TimeUnits as _;
95
96use crate::{
97 TaskPanicError,
98 channel::{self, ChannelError, IpcReceiver, IpcSender, IpcValue, NamedIpcSender},
99 process::tap::{StderrTap, contains_ansi_csi, remove_ansi_csi},
100};
101
102use super::tap::PanicInfo;
103
// Environment variables the app process sets on the spawned worker process, read back by
// `run_worker_server` to decide if the current process is a worker.
const WORKER_VERSION: &str = "ZNG_TASK_IPC_WORKER_VERSION";
const WORKER_SERVER: &str = "ZNG_TASK_IPC_WORKER_SERVER";
const WORKER_NAME: &str = "ZNG_TASK_IPC_WORKER_NAME";

// Optional app process variable that overrides the worker connect timeout, in seconds
// (empty or unset uses 10, invalid values log an error and use 10, minimum is 1).
const WORKER_TIMEOUT: &str = "ZNG_TASK_WORKER_TIMEOUT";

/// Version of this crate; a worker process exits if the app process that spawned it has a different version.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
113
/// Handle to a worker process started by [`Worker::start`], [`Worker::start_with`] or [`Worker::start_other`].
///
/// Requests of type `I` are sent to the worker process over IPC and answered with `O`.
pub struct Worker<I: IpcValue, O: IpcValue> {
    // Response receiver thread, the worker process and its piped stderr.
    // `None` after `shutdown`, drop, or once `crash_error` reaped the exited worker.
    running: Option<(std::thread::JoinHandle<()>, std::process::Child, StderrTap)>,

    // IPC channel that sends requests to the worker process.
    sender: IpcSender<(RequestId, Request<I>)>,
    // Pending requests; the receiver thread removes an entry and forwards the worker output to it.
    requests: Arc<Mutex<IdMap<RequestId, channel::Sender<O>>>>,

    _p: PhantomData<fn(I) -> O>,

    // Crash info recorded by `crash_error` after the receiver thread exited.
    crash: Option<WorkerCrashError>,
}
125impl<I: IpcValue, O: IpcValue> Worker<I, O> {
126 pub async fn start(worker_name: impl Into<Txt>) -> std::io::Result<Self> {
131 Self::start_impl(worker_name.into(), std::env::current_exe()?, &[], &[]).await
132 }
133
134 pub async fn start_with(worker_name: impl Into<Txt>, env_vars: &[(&str, &str)], args: &[&str]) -> std::io::Result<Self> {
136 Self::start_impl(worker_name.into(), std::env::current_exe()?, env_vars, args).await
137 }
138
139 pub async fn start_other(
141 worker_name: impl Into<Txt>,
142 worker_exe: impl Into<PathBuf>,
143 env_vars: &[(&str, &str)],
144 args: &[&str],
145 ) -> std::io::Result<Self> {
146 Self::start_impl(worker_name.into(), worker_exe.into(), env_vars, args).await
147 }
148
149 async fn start_impl(worker_name: Txt, exe: PathBuf, env_vars: &[(&str, &str)], args: &[&str]) -> std::io::Result<Self> {
150 let chan_sender = NamedIpcSender::<WorkerInit<I, O>>::new()?;
151
152 let mut worker = std::process::Command::new(dunce::canonicalize(exe)?);
153 for (key, value) in env_vars {
154 worker.env(key, value);
155 }
156 for arg in args {
157 worker.arg(arg);
158 }
159 worker
160 .env(WORKER_VERSION, crate::process::worker::VERSION)
161 .env(WORKER_SERVER, chan_sender.name())
162 .env(WORKER_NAME, worker_name)
163 .env("RUST_BACKTRACE", "full");
164
165 worker.stderr(Stdio::piped());
166
167 let mut worker = blocking::unblock(move || worker.spawn()).await?;
168
169 let timeout = match std::env::var(WORKER_TIMEOUT) {
170 Ok(t) if !t.is_empty() => match t.parse::<u64>() {
171 Ok(t) => t.max(1),
172 Err(e) => {
173 tracing::error!("invalid {WORKER_TIMEOUT:?} value, {e}");
174 10
175 }
176 },
177 _ => 10,
178 };
179
180 let (request_sender, mut response_receiver) = match Self::connect_worker(chan_sender, timeout).await {
181 Ok(r) => r,
182 Err(ce) => {
183 let cleanup = blocking::unblock(move || {
184 worker.kill()?;
185 worker.wait()
186 });
187 match cleanup.await {
188 Ok(status) => {
189 let code = status.code().unwrap_or(0);
190 return Err(std::io::Error::new(
191 std::io::ErrorKind::TimedOut,
192 format!("worker process did not connect in {timeout}s\nworker exit code: {code}\nchannel error: {ce}"),
193 ));
194 }
195 Err(e) => {
196 return Err(std::io::Error::new(
197 std::io::ErrorKind::TimedOut,
198 format!("worker process did not connect in {timeout}s\ncannot kill worker process, {e}\nchannel error: {ce}"),
199 ));
200 }
201 }
202 }
203 };
204
205 let requests = Arc::new(Mutex::new(IdMap::<RequestId, channel::Sender<O>>::new()));
206 let receiver = std::thread::Builder::new()
207 .name("task-ipc-recv".into())
208 .stack_size(256 * 1024)
209 .spawn(clmv!(requests, || {
210 loop {
211 match response_receiver.recv_blocking() {
212 Ok((id, r)) => match requests.lock().remove(&id) {
213 Some(s) => match r {
214 Response::Out(r) => {
215 let _ = s.send_blocking(r);
216 }
217 },
218 None => tracing::error!("worker responded to unknown request #{}", id.sequential()),
219 },
220 Err(e) => match e {
221 ChannelError::Disconnected { .. } => {
222 requests.lock().clear();
223 break;
224 }
225 e => {
226 tracing::error!("worker response error, will shutdown, {e}");
227 break;
228 }
229 },
230 }
231 }
232 }))
233 .expect("failed to spawn thread");
234
235 let stderr_tap = StderrTap::new_blocking(worker.stderr.take().unwrap());
236
237 Ok(Self {
238 running: Some((receiver, worker, stderr_tap)),
239 sender: request_sender,
240 _p: PhantomData,
241 crash: None,
242 requests,
243 })
244 }
245 async fn connect_worker(
246 chan_sender: NamedIpcSender<WorkerInit<I, O>>,
247 timeout: u64,
248 ) -> Result<(IpcSender<(RequestId, Request<I>)>, IpcReceiver<(RequestId, Response<O>)>), ChannelError> {
249 let mut chan_sender = chan_sender.connect_deadline(timeout.secs()).await?;
250
251 let (request_sender, request_receiver) =
252 channel::ipc_unbounded::<(RequestId, Request<I>)>().map_err(ChannelError::disconnected_by)?;
253 let (response_sender, response_receiver) =
254 channel::ipc_unbounded::<(RequestId, Response<O>)>().map_err(ChannelError::disconnected_by)?;
255
256 chan_sender.send_blocking((request_receiver, response_sender))?;
257
258 Ok((request_sender, response_receiver))
259 }
260
261 pub async fn shutdown(mut self) -> std::io::Result<()> {
263 if let Some((receiver, mut worker, _)) = self.running.take() {
264 while !self.requests.lock().is_empty() {
265 crate::deadline(100.ms()).await;
266 }
267 let r = blocking::unblock(move || {
268 worker.kill()?;
269 worker.wait()?;
270 Ok(())
271 })
272 .await;
273
274 match crate::with_deadline(blocking::unblock(move || receiver.join()), 1.secs()).await {
275 Ok(r) => {
276 if let Err(p) = r {
277 tracing::error!(
278 "worker receiver thread exited panicked, {}",
279 TaskPanicError::new(p).panic_str().unwrap_or("")
280 );
281 }
282 }
283 Err(_) => {
284 if r.is_ok() {
286 panic!("worker receiver thread did not exit after worker process did");
288 }
289 }
290 }
291 r
292 } else {
293 Ok(())
294 }
295 }
296
297 pub fn run(&mut self, input: I) -> impl Future<Output = Result<O, RunError>> + Send + 'static {
299 self.run_request(Request::Run(input))
300 }
301
302 fn run_request(&mut self, request: Request<I>) -> Pin<Box<dyn Future<Output = Result<O, RunError>> + Send + 'static>> {
303 if self.crash_error().is_some() {
304 return Box::pin(std::future::ready(Err(RunError::Disconnected)));
305 }
306
307 let id = RequestId::new_unique();
308 let (sx, rx) = channel::bounded(1);
309
310 let requests = self.requests.clone();
311 requests.lock().insert(id, sx);
312 let mut sender = self.sender.clone();
313 let send_r = blocking::unblock(move || sender.send_blocking((id, request)));
314
315 Box::pin(async move {
316 if let Err(e) = send_r.await {
317 tracing::error!("cannot send request, {e}");
318 requests.lock().remove(&id);
319 return Err(RunError::Other(Arc::new(e)));
320 }
321
322 match rx.recv().await {
323 Ok(r) => Ok(r),
324 Err(e) => match e {
325 ChannelError::Disconnected { cause } => {
326 let cause = match cause {
327 Some(e) => format!(", {e}"),
328 None => String::new(),
329 };
330 tracing::error!("cannot receive response, disconnected{cause}, more info in `crash_error`");
331 requests.lock().remove(&id);
332 Err(RunError::Disconnected)
333 }
334 _ => unreachable!(),
335 },
336 }
337 })
338 }
339
340 pub fn crash_error(&mut self) -> Option<&WorkerCrashError> {
344 if let Some((t, _, _)) = &self.running
346 && t.is_finished()
347 {
348 let (t, mut p, stderr) = self.running.take().unwrap();
349
350 if let Err(e) = t.join() {
351 tracing::error!(
352 "panic in worker receiver thread, {}",
353 TaskPanicError::new(e).panic_str().unwrap_or("")
354 );
355 }
356
357 if let Err(e) = p.kill() {
358 tracing::error!("error killing worker process after receiver exit, {e}");
359 }
360
361 match p.wait() {
362 Ok(o) => {
363 self.crash = Some(WorkerCrashError {
364 status: o,
365 stderr: stderr.into_txt_blocking(false),
366 });
367 }
368 Err(e) => tracing::error!("error reading crashed worker output, {e}"),
369 }
370 }
371
372 self.crash.as_ref()
373 }
374}
375impl<I: IpcValue, O: IpcValue> Drop for Worker<I, O> {
376 fn drop(&mut self) {
377 if let Some((receiver, mut process, _)) = self.running.take() {
378 if !receiver.is_finished() {
379 tracing::error!("dropped worker without shutdown");
380 }
381 if let Err(e) = process.kill() {
382 tracing::error!("failed to kill worker process on drop, {e}");
383 }
384 }
385 }
386}
387
388pub fn run_worker<I, O, F>(worker_name: impl Into<Txt>, handler: impl Fn(RequestArgs<I>) -> F + Send + Sync + 'static)
393where
394 I: IpcValue,
395 O: IpcValue,
396 F: Future<Output = O> + Send + 'static,
397{
398 let name = worker_name.into();
399 if let Some(server_name) = run_worker_server(&name) {
400 zng_env::init_process_name(zng_txt::formatx!("worker-process ({name}, {})", std::process::id()));
401
402 let mut chan_recv = IpcReceiver::<WorkerInit<I, O>>::connect(server_name)
403 .unwrap_or_else(|e| panic!("failed to connect to '{name}' init channel, {e}"));
404
405 let (mut request_receiver, response_sender) = chan_recv
406 .recv_blocking()
407 .unwrap_or_else(|e| panic!("failed to connect initial channels, {e}"));
408
409 let handler = Arc::new(handler);
410
411 loop {
412 match request_receiver.recv_blocking() {
413 Ok((id, input)) => match input {
414 Request::Run(r) => crate::spawn(async_clmv!(handler, mut response_sender, {
415 let output = handler(RequestArgs { request: r }).await;
416 let _ = response_sender.send_blocking((id, Response::Out(output)));
417 })),
418 },
419 Err(e) => match e {
420 ChannelError::Disconnected { cause } => {
421 match cause {
422 Some(e) => tracing::error!("exit worker, disconnected, {e}"),
423 None => tracing::debug!("exit worker, disconnected"),
424 }
425 break;
426 }
427 ChannelError::Timeout => unreachable!(),
428 },
429 }
430 }
431
432 zng_env::exit(0);
433 }
434}
435fn run_worker_server(worker_name: &str) -> Option<String> {
436 if let Ok(w_name) = std::env::var(WORKER_NAME)
437 && let Ok(version) = std::env::var(WORKER_VERSION)
438 && let Ok(server_name) = std::env::var(WORKER_SERVER)
439 {
440 if w_name != worker_name {
441 return None;
442 }
443 if version != VERSION {
444 eprintln!("worker '{worker_name}' API version is not equal, app-process: {version}, worker-process: {VERSION}");
445 zng_env::exit(i32::from_le_bytes(*b"vapi"));
446 }
447
448 Some(server_name)
449 } else {
450 None
451 }
452}
453
/// Arguments passed to the request handler of [`run_worker`].
#[non_exhaustive]
pub struct RequestArgs<I: IpcValue> {
    /// The request input, as given to [`Worker::run`] in the app process.
    pub request: I,
}
460
/// Error running a request in a worker process.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum RunError {
    /// Lost connection with the worker process.
    ///
    /// The worker crash cause can be inspected with `Worker::crash_error`.
    Disconnected,
    /// Other error, for example the request could not be sent to the worker.
    Other(Arc<dyn std::error::Error + Send + Sync>),
}
impl fmt::Display for RunError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Disconnected => f.write_str("worker process disconnected"),
            Self::Other(cause) => write!(f, "run error, {cause}"),
        }
    }
}
impl std::error::Error for RunError {}
481
482#[derive(Debug, Clone)]
484#[non_exhaustive]
485pub struct WorkerCrashError {
486 pub status: std::process::ExitStatus,
488 pub stderr: Txt,
490}
491impl WorkerCrashError {
492 pub fn is_stderr_plain(&self) -> bool {
494 !contains_ansi_csi(&self.stderr)
495 }
496
497 pub fn stderr_plain(&self) -> Txt {
499 if self.is_stderr_plain() {
500 self.stderr.clone()
501 } else {
502 remove_ansi_csi(&self.stderr)
503 }
504 }
505
506 pub fn find_panic(&self) -> Option<PanicInfo> {
511 if self.status.code() == Some(101) {
512 PanicInfo::find(&self.stderr)
513 } else {
514 None
515 }
516 }
517}
518impl fmt::Display for WorkerCrashError {
519 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
520 write!(f, "{:?}", self.status)
521 }
522}
523impl std::error::Error for WorkerCrashError {}
524
/// Message sent from the app process to the worker process.
#[derive(serde::Serialize, serde::Deserialize)]
enum Request<I> {
    /// Run the worker handler with this input.
    Run(I),
}

/// Message sent from the worker process back to the app process.
#[derive(serde::Serialize, serde::Deserialize)]
enum Response<O> {
    /// Output of the worker handler for a `Request::Run`.
    Out(O),
}

/// Channels the app process sends over the named init channel: the worker receives requests
/// on the first and sends responses on the second.
type WorkerInit<I, O> = (
    channel::IpcReceiver<(RequestId, Request<I>)>,
    channel::IpcSender<(RequestId, Response<O>)>,
);
539
// Unique ID that pairs each `Response` with the `Request` it answers; it is serialized
// as part of the IPC messages exchanged with the worker process.
zng_unique_id::unique_id_64! {
    #[derive(serde::Serialize, serde::Deserialize)]
    struct RequestId;
}