1#![cfg(ipc)]
2
3use core::fmt;
89use std::{marker::PhantomData, path::PathBuf, pin::Pin, sync::Arc};
90
91use parking_lot::Mutex;
92use zng_clone_move::{async_clmv, clmv};
93use zng_txt::{ToTxt, Txt};
94use zng_unique_id::IdMap;
95use zng_unit::TimeUnits as _;
96
97#[doc(no_inline)]
98pub use ipc_channel::ipc::{IpcBytesReceiver, IpcBytesSender, IpcReceiver, IpcSender, bytes_channel};
99
100use crate::TaskPanicError;
101
102#[diagnostic::on_unimplemented(note = "`IpcValue` is implemented for all `T: Debug + Serialize + Deserialize + Send + 'static`")]
115pub trait IpcValue: fmt::Debug + serde::Serialize + for<'d> serde::de::Deserialize<'d> + Send + 'static {}
116
117impl<T: fmt::Debug + serde::Serialize + for<'d> serde::de::Deserialize<'d> + Send + 'static> IpcValue for T {}
118
// Environment variables set by the app process when it starts a worker process and read back by
// the worker side in `run_worker`.

/// Env var carrying the app-process API `VERSION`; the worker exits if it differs from its own.
const WORKER_VERSION: &str = "ZNG_TASK_IPC_WORKER_VERSION";
/// Env var carrying the name of the IPC one-shot server the worker must connect to.
const WORKER_SERVER: &str = "ZNG_TASK_IPC_WORKER_SERVER";
/// Env var carrying the worker name, compared with the name passed to `run_worker`.
const WORKER_NAME: &str = "ZNG_TASK_IPC_WORKER_NAME";

/// Env var read by the app process: seconds to wait for a worker to connect (default 10, min 1).
const WORKER_TIMEOUT: &str = "ZNG_TASK_WORKER_TIMEOUT";
124
125pub const VERSION: &str = env!("CARGO_PKG_VERSION");
128
129pub struct Worker<I: IpcValue, O: IpcValue> {
131 running: Option<(std::thread::JoinHandle<()>, duct::Handle)>,
132
133 sender: ipc_channel::ipc::IpcSender<(RequestId, Request<I>)>,
134 requests: Arc<Mutex<IdMap<RequestId, flume::Sender<O>>>>,
135
136 _p: PhantomData<fn(I) -> O>,
137
138 crash: Option<WorkerCrashError>,
139}
140impl<I: IpcValue, O: IpcValue> Worker<I, O> {
141 pub async fn start(worker_name: impl Into<Txt>) -> std::io::Result<Self> {
146 Self::start_impl(worker_name.into(), duct::cmd!(dunce::canonicalize(std::env::current_exe()?)?)).await
147 }
148
149 pub async fn start_with(worker_name: impl Into<Txt>, env_vars: &[(&str, &str)], args: &[&str]) -> std::io::Result<Self> {
151 let mut worker = duct::cmd(dunce::canonicalize(std::env::current_exe()?)?, args);
152 for (name, value) in env_vars {
153 worker = worker.env(name, value);
154 }
155 Self::start_impl(worker_name.into(), worker).await
156 }
157
158 pub async fn start_other(
160 worker_name: impl Into<Txt>,
161 worker_exe: impl Into<PathBuf>,
162 env_vars: &[(&str, &str)],
163 args: &[&str],
164 ) -> std::io::Result<Self> {
165 let mut worker = duct::cmd(worker_exe.into(), args);
166 for (name, value) in env_vars {
167 worker = worker.env(name, value);
168 }
169 Self::start_impl(worker_name.into(), worker).await
170 }
171
172 pub async fn start_duct(worker_name: impl Into<Txt>, worker: duct::Expression) -> std::io::Result<Self> {
180 Self::start_impl(worker_name.into(), worker).await
181 }
182
183 async fn start_impl(worker_name: Txt, worker: duct::Expression) -> std::io::Result<Self> {
184 let (server, name) = ipc_channel::ipc::IpcOneShotServer::<WorkerInit<I, O>>::new()?;
185
186 let worker = worker
187 .env(WORKER_VERSION, crate::ipc::VERSION)
188 .env(WORKER_SERVER, name)
189 .env(WORKER_NAME, worker_name)
190 .env("RUST_BACKTRACE", "full")
191 .stdin_null()
192 .stdout_capture()
193 .stderr_capture()
194 .unchecked();
195
196 let process = crate::wait(move || worker.start()).await?;
197
198 let timeout = match std::env::var(WORKER_TIMEOUT) {
199 Ok(t) if !t.is_empty() => match t.parse::<u64>() {
200 Ok(t) => t.max(1),
201 Err(e) => {
202 tracing::error!("invalid {WORKER_TIMEOUT:?} value, {e}");
203 10
204 }
205 },
206 _ => 10,
207 };
208
209 let r = crate::with_deadline(crate::wait(move || server.accept()), timeout.secs()).await;
210
211 let (_, (req_sender, chan_sender)) = match r {
212 Ok(r) => match r {
213 Ok(r) => r,
214 Err(e) => return Err(std::io::Error::new(std::io::ErrorKind::ConnectionRefused, e)),
215 },
216 Err(_) => match process.kill() {
217 Ok(()) => {
218 let output = process.wait().unwrap();
219 let stdout = String::from_utf8_lossy(&output.stdout);
220 let stderr = String::from_utf8_lossy(&output.stderr);
221 let code = output.status.code().unwrap_or(0);
222 return Err(std::io::Error::new(
223 std::io::ErrorKind::TimedOut,
224 format!(
225 "worker process did not connect in {timeout}s\nworker exit code: {code}\n--worker stdout--\n{stdout}\n--worker stderr--\n{stderr}"
226 ),
227 ));
228 }
229 Err(e) => {
230 return Err(std::io::Error::new(
231 std::io::ErrorKind::TimedOut,
232 format!("worker process did not connect in {timeout}s\ncannot be kill worker process, {e}"),
233 ));
234 }
235 },
236 };
237
238 let (rsp_sender, rsp_recv) = ipc_channel::ipc::channel()?;
239 crate::wait(move || chan_sender.send(rsp_sender)).await.unwrap();
240
241 let requests = Arc::new(Mutex::new(IdMap::<RequestId, flume::Sender<O>>::new()));
242 let receiver = std::thread::Builder::new()
243 .name("task-ipc-recv".into())
244 .stack_size(256 * 1024)
245 .spawn(clmv!(requests, || {
246 loop {
247 match rsp_recv.recv() {
248 Ok((id, r)) => match requests.lock().remove(&id) {
249 Some(s) => match r {
250 Response::Out(r) => {
251 let _ = s.send(r);
252 }
253 },
254 None => tracing::error!("worker responded to unknown request #{}", id.sequential()),
255 },
256 Err(e) => match e {
257 ipc_channel::ipc::IpcError::Disconnected => {
258 requests.lock().clear();
259 break;
260 }
261 ipc_channel::ipc::IpcError::Bincode(e) => {
262 tracing::error!("worker response error, will shutdown, {e}");
263 break;
264 }
265 ipc_channel::ipc::IpcError::Io(e) => {
266 tracing::error!("worker response io error, will shutdown, {e}");
267 break;
268 }
269 },
270 }
271 }
272 }))
273 .expect("failed to spawn thread");
274
275 Ok(Self {
276 running: Some((receiver, process)),
277 sender: req_sender,
278 _p: PhantomData,
279 crash: None,
280 requests,
281 })
282 }
283
284 pub async fn shutdown(mut self) -> std::io::Result<()> {
286 if let Some((receiver, process)) = self.running.take() {
287 while !self.requests.lock().is_empty() {
288 crate::deadline(100.ms()).await;
289 }
290 let r = crate::wait(move || process.kill()).await;
291
292 match crate::with_deadline(crate::wait(move || receiver.join()), 1.secs()).await {
293 Ok(r) => {
294 if let Err(p) = r {
295 tracing::error!(
296 "worker receiver thread exited panicked, {}",
297 TaskPanicError::new(p).panic_str().unwrap_or("")
298 );
299 }
300 }
301 Err(_) => {
302 if r.is_ok() {
304 panic!("worker receiver thread did not exit after worker process did");
306 }
307 }
308 }
309 r
310 } else {
311 Ok(())
312 }
313 }
314
315 pub fn run(&mut self, input: I) -> impl Future<Output = Result<O, RunError>> + Send + 'static {
317 self.run_request(Request::Run(input))
318 }
319
320 fn run_request(&mut self, request: Request<I>) -> Pin<Box<dyn Future<Output = Result<O, RunError>> + Send + 'static>> {
321 if self.crash_error().is_some() {
322 return Box::pin(std::future::ready(Err(RunError::Disconnected)));
323 }
324
325 let id = RequestId::new_unique();
326 let (sx, rx) = flume::bounded(1);
327
328 let requests = self.requests.clone();
329 requests.lock().insert(id, sx);
330 let sender = self.sender.clone();
331 let send_r = crate::wait(move || sender.send((id, request)));
332
333 Box::pin(async move {
334 if let Err(e) = send_r.await {
335 requests.lock().remove(&id);
336 return Err(RunError::Other(Arc::new(e)));
337 }
338
339 match rx.recv_async().await {
340 Ok(r) => Ok(r),
341 Err(e) => match e {
342 flume::RecvError::Disconnected => {
343 requests.lock().remove(&id);
344 Err(RunError::Disconnected)
345 }
346 },
347 }
348 })
349 }
350
351 pub fn crash_error(&mut self) -> Option<&WorkerCrashError> {
355 if let Some((t, _)) = &self.running
356 && t.is_finished()
357 {
358 let (t, p) = self.running.take().unwrap();
359
360 if let Err(e) = t.join() {
361 tracing::error!(
362 "panic in worker receiver thread, {}",
363 TaskPanicError::new(e).panic_str().unwrap_or("")
364 );
365 }
366
367 if let Err(e) = p.kill() {
368 tracing::error!("error killing worker process after receiver exit, {e}");
369 }
370
371 match p.into_output() {
372 Ok(o) => {
373 self.crash = Some(WorkerCrashError {
374 status: o.status,
375 stdout: String::from_utf8_lossy(&o.stdout[..]).as_ref().to_txt(),
376 stderr: String::from_utf8_lossy(&o.stderr[..]).as_ref().to_txt(),
377 });
378 }
379 Err(e) => tracing::error!("error reading crashed worker output, {e}"),
380 }
381 }
382
383 self.crash.as_ref()
384 }
385}
386impl<I: IpcValue, O: IpcValue> Drop for Worker<I, O> {
387 fn drop(&mut self) {
388 if let Some((receiver, process)) = self.running.take() {
389 if !receiver.is_finished() {
390 tracing::error!("dropped worker without shutdown");
391 }
392 if let Err(e) = process.kill() {
393 tracing::error!("failed to kill worker process on drop, {e}");
394 }
395 }
396 }
397}
398
399pub fn run_worker<I, O, F>(worker_name: impl Into<Txt>, handler: impl Fn(RequestArgs<I>) -> F + Send + Sync + 'static)
404where
405 I: IpcValue,
406 O: IpcValue,
407 F: Future<Output = O> + Send + Sync + 'static,
408{
409 let name = worker_name.into();
410 zng_env::init_process_name(zng_txt::formatx!("worker-process ({name}, {})", std::process::id()));
411 if let Some(server_name) = run_worker_server(&name) {
412 let app_init_sender = IpcSender::<WorkerInit<I, O>>::connect(server_name)
413 .unwrap_or_else(|e| panic!("failed to connect to '{name}' init channel, {e}"));
414
415 let (req_sender, req_recv) = ipc_channel::ipc::channel().unwrap();
416 let (chan_sender, chan_recv) = ipc_channel::ipc::channel().unwrap();
417
418 app_init_sender.send((req_sender, chan_sender)).unwrap();
419 let rsp_sender = chan_recv.recv().unwrap();
420 let handler = Arc::new(handler);
421
422 loop {
423 match req_recv.recv() {
424 Ok((id, input)) => match input {
425 Request::Run(r) => crate::spawn(async_clmv!(handler, rsp_sender, {
426 let output = handler(RequestArgs { request: r }).await;
427 let _ = rsp_sender.send((id, Response::Out(output)));
428 })),
429 },
430 Err(e) => match e {
431 ipc_channel::ipc::IpcError::Bincode(e) => {
432 eprintln!("worker '{name}' request error, {e}")
433 }
434 ipc_channel::ipc::IpcError::Io(e) => panic!("worker '{name}' request io error, {e}"),
435 ipc_channel::ipc::IpcError::Disconnected => break,
436 },
437 }
438 }
439
440 zng_env::exit(0);
441 }
442}
443fn run_worker_server(worker_name: &str) -> Option<String> {
444 if let Ok(w_name) = std::env::var(WORKER_NAME)
445 && let Ok(version) = std::env::var(WORKER_VERSION)
446 && let Ok(server_name) = std::env::var(WORKER_SERVER)
447 {
448 if w_name != worker_name {
449 return None;
450 }
451 if version != VERSION {
452 eprintln!("worker '{worker_name}' API version is not equal, app-process: {version}, worker-process: {VERSION}");
453 zng_env::exit(i32::from_le_bytes(*b"vapi"));
454 }
455
456 Some(server_name)
457 } else {
458 None
459 }
460}
461
462#[non_exhaustive]
464pub struct RequestArgs<I: IpcValue> {
465 pub request: I,
467}
468
/// Error returned by `Worker::run`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum RunError {
    /// The worker connection is lost; the request will not be answered.
    Disconnected,
    /// Other error, such as a failure to send the request to the worker process.
    Other(Arc<dyn std::error::Error + Send + Sync>),
}
impl fmt::Display for RunError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Disconnected => f.write_str("worker process disconnected"),
            Self::Other(cause) => write!(f, "run error, {cause}"),
        }
    }
}
impl std::error::Error for RunError {}
489
490#[derive(Debug, Clone)]
492#[non_exhaustive]
493pub struct WorkerCrashError {
494 pub status: std::process::ExitStatus,
496 pub stdout: Txt,
498 pub stderr: Txt,
500}
501impl fmt::Display for WorkerCrashError {
502 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
503 write!(f, "{:?}\nSTDOUT:\n{}\nSTDERR:\n{}", self.status, &self.stdout, &self.stderr)
504 }
505}
506impl std::error::Error for WorkerCrashError {}
507
508#[derive(serde::Serialize, serde::Deserialize)]
509enum Request<I> {
510 Run(I),
511}
512
513#[derive(serde::Serialize, serde::Deserialize)]
514enum Response<O> {
515 Out(O),
516}
517
518type WorkerInit<I, O> = (IpcSender<(RequestId, Request<I>)>, IpcSender<IpcSender<(RequestId, Response<O>)>>);
527
// Unique ID of a request. It is sent along with each `Request` and echoed back with its
// `Response`, so the app-side receiver thread can route the output to the waiting `run` future.
zng_unique_id::unique_id_64! {
    #[derive(serde::Serialize, serde::Deserialize)]
    struct RequestId;
}