1use core::fmt;
88use std::{marker::PhantomData, path::PathBuf, pin::Pin, sync::Arc};
89
90use parking_lot::Mutex;
91use zng_clone_move::{async_clmv, clmv};
92use zng_txt::Txt;
93use zng_unique_id::IdMap;
94use zng_unit::TimeUnits as _;
95
96use crate::{
97 TaskPanicError,
98 channel::{self, ChannelError, IpcReceiver, IpcSender, IpcValue, NamedIpcSender},
99};
100
/// Environment variable set on the spawned worker process with the app-process [`VERSION`].
101const WORKER_VERSION: &str = "ZNG_TASK_IPC_WORKER_VERSION";
/// Environment variable set on the spawned worker process with the name of the init IPC channel.
102const WORKER_SERVER: &str = "ZNG_TASK_IPC_WORKER_SERVER";
/// Environment variable set on the spawned worker process with the requested worker name.
103const WORKER_NAME: &str = "ZNG_TASK_IPC_WORKER_NAME";
104
/// Environment variable read by the app-process to override the worker connection timeout, in seconds.
///
/// A parsed value of `0` is raised to `1`; unset, empty or invalid values fall back to `10`.
105const WORKER_TIMEOUT: &str = "ZNG_TASK_WORKER_TIMEOUT";
106
/// Package version taken from `CARGO_PKG_VERSION` at compile time.
///
/// The worker process compares it with the version received from the app-process and exits on mismatch.
107pub const VERSION: &str = env!("CARGO_PKG_VERSION");
110
/// Handle to a worker process that receives requests of type `I` and responds with values of type `O`.
111pub struct Worker<I: IpcValue, O: IpcValue> {
    // Receiver thread and child process; `None` after shutdown, crash detection or drop took them.
113 running: Option<(std::thread::JoinHandle<()>, std::process::Child)>,
114
    // Sends `(id, request)` pairs to the worker process.
115 sender: IpcSender<(RequestId, Request<I>)>,
    // Response senders of the requests still awaiting a response, keyed by request ID.
116 requests: Arc<Mutex<IdMap<RequestId, channel::Sender<O>>>>,
117
    // Marker for the input/output types, no value of `I` or `O` is stored.
118 _p: PhantomData<fn(I) -> O>,
119
    // Exit status recorded by `crash_error` after the receiver thread exited.
120 crash: Option<WorkerCrashError>,
121}
122impl<I: IpcValue, O: IpcValue> Worker<I, O> {
123 pub async fn start(worker_name: impl Into<Txt>) -> std::io::Result<Self> {
128 Self::start_impl(worker_name.into(), std::env::current_exe()?, &[], &[]).await
129 }
130
131 pub async fn start_with(worker_name: impl Into<Txt>, env_vars: &[(&str, &str)], args: &[&str]) -> std::io::Result<Self> {
133 Self::start_impl(worker_name.into(), std::env::current_exe()?, env_vars, args).await
134 }
135
136 pub async fn start_other(
138 worker_name: impl Into<Txt>,
139 worker_exe: impl Into<PathBuf>,
140 env_vars: &[(&str, &str)],
141 args: &[&str],
142 ) -> std::io::Result<Self> {
143 Self::start_impl(worker_name.into(), worker_exe.into(), env_vars, args).await
144 }
145
146 async fn start_impl(worker_name: Txt, exe: PathBuf, env_vars: &[(&str, &str)], args: &[&str]) -> std::io::Result<Self> {
147 let chan_sender = NamedIpcSender::<WorkerInit<I, O>>::new()?;
148
149 let mut worker = std::process::Command::new(dunce::canonicalize(exe)?);
150 for (key, value) in env_vars {
151 worker.env(key, value);
152 }
153 for arg in args {
154 worker.arg(arg);
155 }
156 worker
157 .env(WORKER_VERSION, crate::process::worker::VERSION)
158 .env(WORKER_SERVER, chan_sender.name())
159 .env(WORKER_NAME, worker_name)
160 .env("RUST_BACKTRACE", "full");
161 let mut worker = blocking::unblock(move || worker.spawn()).await?;
162
163 let timeout = match std::env::var(WORKER_TIMEOUT) {
164 Ok(t) if !t.is_empty() => match t.parse::<u64>() {
165 Ok(t) => t.max(1),
166 Err(e) => {
167 tracing::error!("invalid {WORKER_TIMEOUT:?} value, {e}");
168 10
169 }
170 },
171 _ => 10,
172 };
173
174 let (request_sender, mut response_receiver) = match Self::connect_worker(chan_sender, timeout).await {
175 Ok(r) => r,
176 Err(ce) => {
177 let cleanup = blocking::unblock(move || {
178 worker.kill()?;
179 worker.wait()
180 });
181 match cleanup.await {
182 Ok(status) => {
183 let code = status.code().unwrap_or(0);
184 return Err(std::io::Error::new(
185 std::io::ErrorKind::TimedOut,
186 format!("worker process did not connect in {timeout}s\nworker exit code: {code}\nchannel error: {ce}"),
187 ));
188 }
189 Err(e) => {
190 return Err(std::io::Error::new(
191 std::io::ErrorKind::TimedOut,
192 format!("worker process did not connect in {timeout}s\ncannot kill worker process, {e}\nchannel error: {ce}"),
193 ));
194 }
195 }
196 }
197 };
198
199 let requests = Arc::new(Mutex::new(IdMap::<RequestId, channel::Sender<O>>::new()));
200 let receiver = std::thread::Builder::new()
201 .name("task-ipc-recv".into())
202 .stack_size(256 * 1024)
203 .spawn(clmv!(requests, || {
204 loop {
205 match response_receiver.recv_blocking() {
206 Ok((id, r)) => match requests.lock().remove(&id) {
207 Some(s) => match r {
208 Response::Out(r) => {
209 let _ = s.send_blocking(r);
210 }
211 },
212 None => tracing::error!("worker responded to unknown request #{}", id.sequential()),
213 },
214 Err(e) => match e {
215 ChannelError::Disconnected { .. } => {
216 requests.lock().clear();
217 break;
218 }
219 e => {
220 tracing::error!("worker response error, will shutdown, {e}");
221 break;
222 }
223 },
224 }
225 }
226 }))
227 .expect("failed to spawn thread");
228
229 Ok(Self {
230 running: Some((receiver, worker)),
231 sender: request_sender,
232 _p: PhantomData,
233 crash: None,
234 requests,
235 })
236 }
237 async fn connect_worker(
238 chan_sender: NamedIpcSender<WorkerInit<I, O>>,
239 timeout: u64,
240 ) -> Result<(IpcSender<(RequestId, Request<I>)>, IpcReceiver<(RequestId, Response<O>)>), ChannelError> {
241 let mut chan_sender = chan_sender.connect_deadline(timeout.secs()).await?;
242
243 let (request_sender, request_receiver) =
244 channel::ipc_unbounded::<(RequestId, Request<I>)>().map_err(ChannelError::disconnected_by)?;
245 let (response_sender, response_receiver) =
246 channel::ipc_unbounded::<(RequestId, Response<O>)>().map_err(ChannelError::disconnected_by)?;
247
248 chan_sender.send_blocking((request_receiver, response_sender))?;
249
250 Ok((request_sender, response_receiver))
251 }
252
253 pub async fn shutdown(mut self) -> std::io::Result<()> {
255 if let Some((receiver, mut process)) = self.running.take() {
256 while !self.requests.lock().is_empty() {
257 crate::deadline(100.ms()).await;
258 }
259 let r = blocking::unblock(move || process.kill()).await;
260
261 match crate::with_deadline(blocking::unblock(move || receiver.join()), 1.secs()).await {
262 Ok(r) => {
263 if let Err(p) = r {
264 tracing::error!(
265 "worker receiver thread exited panicked, {}",
266 TaskPanicError::new(p).panic_str().unwrap_or("")
267 );
268 }
269 }
270 Err(_) => {
271 if r.is_ok() {
273 panic!("worker receiver thread did not exit after worker process did");
275 }
276 }
277 }
278 r
279 } else {
280 Ok(())
281 }
282 }
283
284 pub fn run(&mut self, input: I) -> impl Future<Output = Result<O, RunError>> + Send + 'static {
286 self.run_request(Request::Run(input))
287 }
288
289 fn run_request(&mut self, request: Request<I>) -> Pin<Box<dyn Future<Output = Result<O, RunError>> + Send + 'static>> {
290 if self.crash_error().is_some() {
291 return Box::pin(std::future::ready(Err(RunError::Disconnected)));
292 }
293
294 let id = RequestId::new_unique();
295 let (sx, rx) = channel::bounded(1);
296
297 let requests = self.requests.clone();
298 requests.lock().insert(id, sx);
299 let mut sender = self.sender.clone();
300 let send_r = blocking::unblock(move || sender.send_blocking((id, request)));
301
302 Box::pin(async move {
303 if let Err(e) = send_r.await {
304 tracing::error!("cannot send request, {e}");
305 requests.lock().remove(&id);
306 return Err(RunError::Other(Arc::new(e)));
307 }
308
309 match rx.recv().await {
310 Ok(r) => Ok(r),
311 Err(e) => match e {
312 ChannelError::Disconnected { cause } => {
313 let cause = match cause {
314 Some(e) => format!(", {e}"),
315 None => String::new(),
316 };
317 tracing::error!("cannot receive response, disconnected{cause}, more info in `crash_error`");
318 requests.lock().remove(&id);
319 Err(RunError::Disconnected)
320 }
321 _ => unreachable!(),
322 },
323 }
324 })
325 }
326
327 pub fn crash_error(&mut self) -> Option<&WorkerCrashError> {
331 if let Some((t, _)) = &self.running
332 && t.is_finished()
333 {
334 let (t, mut p) = self.running.take().unwrap();
335
336 if let Err(e) = t.join() {
337 tracing::error!(
338 "panic in worker receiver thread, {}",
339 TaskPanicError::new(e).panic_str().unwrap_or("")
340 );
341 }
342
343 if let Err(e) = p.kill() {
344 tracing::error!("error killing worker process after receiver exit, {e}");
345 }
346
347 match p.wait() {
348 Ok(o) => {
349 self.crash = Some(WorkerCrashError { status: o });
350 }
351 Err(e) => tracing::error!("error reading crashed worker output, {e}"),
352 }
353 }
354
355 self.crash.as_ref()
356 }
357}
358impl<I: IpcValue, O: IpcValue> Drop for Worker<I, O> {
359 fn drop(&mut self) {
360 if let Some((receiver, mut process)) = self.running.take() {
361 if !receiver.is_finished() {
362 tracing::error!("dropped worker without shutdown");
363 }
364 if let Err(e) = process.kill() {
365 tracing::error!("failed to kill worker process on drop, {e}");
366 }
367 }
368 }
369}
370
371pub fn run_worker<I, O, F>(worker_name: impl Into<Txt>, handler: impl Fn(RequestArgs<I>) -> F + Send + Sync + 'static)
376where
377 I: IpcValue,
378 O: IpcValue,
379 F: Future<Output = O> + Send + 'static,
380{
381 let name = worker_name.into();
382 if let Some(server_name) = run_worker_server(&name) {
383 zng_env::init_process_name(zng_txt::formatx!("worker-process ({name}, {})", std::process::id()));
384
385 let mut chan_recv = IpcReceiver::<WorkerInit<I, O>>::connect(server_name)
386 .unwrap_or_else(|e| panic!("failed to connect to '{name}' init channel, {e}"));
387
388 let (mut request_receiver, response_sender) = chan_recv
389 .recv_blocking()
390 .unwrap_or_else(|e| panic!("failed to connect initial channels, {e}"));
391
392 let handler = Arc::new(handler);
393
394 loop {
395 match request_receiver.recv_blocking() {
396 Ok((id, input)) => match input {
397 Request::Run(r) => crate::spawn(async_clmv!(handler, mut response_sender, {
398 let output = handler(RequestArgs { request: r }).await;
399 let _ = response_sender.send_blocking((id, Response::Out(output)));
400 })),
401 },
402 Err(e) => match e {
403 ChannelError::Disconnected { cause } => {
404 match cause {
405 Some(e) => tracing::error!("exit worker, disconnected, {e}"),
406 None => tracing::debug!("exit worker, disconnected"),
407 }
408 break;
409 }
410 ChannelError::Timeout => unreachable!(),
411 },
412 }
413 }
414
415 zng_env::exit(0);
416 }
417}
418fn run_worker_server(worker_name: &str) -> Option<String> {
419 if let Ok(w_name) = std::env::var(WORKER_NAME)
420 && let Ok(version) = std::env::var(WORKER_VERSION)
421 && let Ok(server_name) = std::env::var(WORKER_SERVER)
422 {
423 if w_name != worker_name {
424 return None;
425 }
426 if version != VERSION {
427 eprintln!("worker '{worker_name}' API version is not equal, app-process: {version}, worker-process: {VERSION}");
428 zng_env::exit(i32::from_le_bytes(*b"vapi"));
429 }
430
431 Some(server_name)
432 } else {
433 None
434 }
435}
436
/// Arguments for the handler passed to [`run_worker`].
437#[non_exhaustive]
439pub struct RequestArgs<I: IpcValue> {
    /// The input value of a run request received from the app-process.
440 pub request: I,
442}
443
/// Error of a worker run request.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum RunError {
    /// The worker process is disconnected, either a crash was already detected or the response
    /// channel disconnected before a response was received.
    Disconnected,
    /// The request could not be sent to the worker process.
    Other(Arc<dyn std::error::Error + Send + Sync>),
}
impl fmt::Display for RunError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Disconnected => f.write_str("worker process disconnected"),
            Self::Other(e) => f.write_fmt(format_args!("run error, {e}")),
        }
    }
}
impl std::error::Error for RunError {}
464
/// Info about a worker process that stopped outside of a shutdown.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct WorkerCrashError {
    /// Exit status of the worker process.
    pub status: std::process::ExitStatus,
}
impl fmt::Display for WorkerCrashError {
    /// Writes the debug representation of the exit status.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let status = &self.status;
        f.write_fmt(format_args!("{status:?}"))
    }
}
impl std::error::Error for WorkerCrashError {}
478
// Message sent from the app-process to the worker process, paired with a `RequestId`.
479#[derive(serde::Serialize, serde::Deserialize)]
480enum Request<I> {
    // Run the worker handler with this input.
481 Run(I),
482}
483
// Message sent from the worker process to the app-process, paired with the `RequestId` it answers.
484#[derive(serde::Serialize, serde::Deserialize)]
485enum Response<O> {
    // Output of the worker handler.
486 Out(O),
487}
488
// Init message sent to the worker process over the named init channel:
// the worker side request receiver and response sender.
489type WorkerInit<I, O> = (
490 channel::IpcReceiver<(RequestId, Request<I>)>,
491 channel::IpcSender<(RequestId, Response<O>)>,
492);
493
// ID that pairs each response with its pending request; serializable because it is sent
// over IPC together with the request and the response.
494zng_unique_id::unique_id_64! {
495 #[derive(serde::Serialize, serde::Deserialize)]
496 struct RequestId;
497}