1#![cfg(ipc)]
2
3use core::fmt;
83use std::{marker::PhantomData, path::PathBuf, pin::Pin, sync::Arc};
84
85use parking_lot::Mutex;
86use zng_clone_move::{async_clmv, clmv};
87use zng_txt::{ToTxt, Txt};
88use zng_unique_id::IdMap;
89use zng_unit::TimeUnits as _;
90
91#[doc(no_inline)]
92pub use ipc_channel::ipc::{IpcBytesReceiver, IpcBytesSender, IpcReceiver, IpcSender, bytes_channel};
93
94#[diagnostic::on_unimplemented(note = "`IpcValue` is implemented for all `T: Debug + Serialize + Deserialize + Send + 'static`")]
107pub trait IpcValue: fmt::Debug + serde::Serialize + for<'d> serde::de::Deserialize<'d> + Send + 'static {}
108
109impl<T: fmt::Debug + serde::Serialize + for<'d> serde::de::Deserialize<'d> + Send + 'static> IpcValue for T {}
110
// Environment variables the app-process sets on a worker process (see `Worker::start_impl`) and that
// the worker reads back in `run_worker_server` to detect that it must run as a worker.

/// Name of the worker the process must run.
const WORKER_NAME: &str = "ZNG_TASK_IPC_WORKER_NAME";
/// API version of the app-process, must match the worker `VERSION`.
const WORKER_VERSION: &str = "ZNG_TASK_IPC_WORKER_VERSION";
/// Name of the one-shot IPC server the worker connects to.
const WORKER_SERVER: &str = "ZNG_TASK_IPC_WORKER_SERVER";
114
115pub const VERSION: &str = env!("CARGO_PKG_VERSION");
118
119pub struct Worker<I: IpcValue, O: IpcValue> {
121 running: Option<(std::thread::JoinHandle<()>, duct::Handle)>,
122
123 sender: ipc_channel::ipc::IpcSender<(RequestId, Request<I>)>,
124 requests: Arc<Mutex<IdMap<RequestId, flume::Sender<O>>>>,
125
126 _p: PhantomData<fn(I) -> O>,
127
128 crash: Option<WorkerCrashError>,
129}
130impl<I: IpcValue, O: IpcValue> Worker<I, O> {
131 pub async fn start(worker_name: impl Into<Txt>) -> std::io::Result<Self> {
136 Self::start_impl(worker_name.into(), duct::cmd!(dunce::canonicalize(std::env::current_exe()?)?)).await
137 }
138
139 pub async fn start_with(worker_name: impl Into<Txt>, env_vars: &[(&str, &str)], args: &[&str]) -> std::io::Result<Self> {
141 let mut worker = duct::cmd(dunce::canonicalize(std::env::current_exe()?)?, args);
142 for (name, value) in env_vars {
143 worker = worker.env(name, value);
144 }
145 Self::start_impl(worker_name.into(), worker).await
146 }
147
148 pub async fn start_other(
150 worker_name: impl Into<Txt>,
151 worker_exe: impl Into<PathBuf>,
152 env_vars: &[(&str, &str)],
153 args: &[&str],
154 ) -> std::io::Result<Self> {
155 let mut worker = duct::cmd(worker_exe.into(), args);
156 for (name, value) in env_vars {
157 worker = worker.env(name, value);
158 }
159 Self::start_impl(worker_name.into(), worker).await
160 }
161
162 pub async fn start_duct(worker_name: impl Into<Txt>, worker: duct::Expression) -> std::io::Result<Self> {
170 Self::start_impl(worker_name.into(), worker).await
171 }
172
173 async fn start_impl(worker_name: Txt, worker: duct::Expression) -> std::io::Result<Self> {
174 let (server, name) = ipc_channel::ipc::IpcOneShotServer::<WorkerInit<I, O>>::new()?;
175
176 let worker = worker
177 .env(WORKER_VERSION, crate::ipc::VERSION)
178 .env(WORKER_SERVER, name)
179 .env(WORKER_NAME, worker_name)
180 .env("RUST_BACKTRACE", "full")
181 .stdin_null()
182 .stdout_capture()
183 .stderr_capture()
184 .unchecked();
185
186 let process = crate::wait(move || worker.start()).await?;
187
188 let r = crate::with_deadline(crate::wait(move || server.accept()), 10.secs()).await;
189
190 let (_, (req_sender, chan_sender)) = match r {
191 Ok(r) => match r {
192 Ok(r) => r,
193 Err(e) => return Err(std::io::Error::new(std::io::ErrorKind::ConnectionRefused, e)),
194 },
195 Err(_) => match process.kill() {
196 Ok(()) => {
197 let output = process.wait().unwrap();
198 let stdout = String::from_utf8_lossy(&output.stdout);
199 let stderr = String::from_utf8_lossy(&output.stderr);
200 let code = output.status.code().unwrap_or(0);
201 return Err(std::io::Error::new(
202 std::io::ErrorKind::TimedOut,
203 format!(
204 "worker process did not connect in 10 seconds\nworker exit code: {code}\n--worker stdout--\n{stdout}\n--worker stderr--\n{stderr}"
205 ),
206 ));
207 }
208 Err(e) => {
209 return Err(std::io::Error::new(
210 std::io::ErrorKind::TimedOut,
211 format!("worker process did not connect in 10s\ncannot be kill worker process, {e}"),
212 ));
213 }
214 },
215 };
216
217 let (rsp_sender, rsp_recv) = ipc_channel::ipc::channel()?;
218 crate::wait(move || chan_sender.send(rsp_sender)).await.unwrap();
219
220 let requests = Arc::new(Mutex::new(IdMap::<RequestId, flume::Sender<O>>::new()));
221 let receiver = std::thread::spawn(clmv!(requests, || {
222 loop {
223 match rsp_recv.recv() {
224 Ok((id, r)) => match requests.lock().remove(&id) {
225 Some(s) => match r {
226 Response::Out(r) => {
227 let _ = s.send(r);
228 }
229 },
230 None => tracing::error!("worker responded to unknown request #{}", id.sequential()),
231 },
232 Err(e) => match e {
233 ipc_channel::ipc::IpcError::Disconnected => {
234 requests.lock().clear();
235 break;
236 }
237 ipc_channel::ipc::IpcError::Bincode(e) => {
238 tracing::error!("worker response error, will shutdown, {e}");
239 break;
240 }
241 ipc_channel::ipc::IpcError::Io(e) => {
242 tracing::error!("worker response io error, will shutdown, {e}");
243 break;
244 }
245 },
246 }
247 }
248 }));
249
250 Ok(Self {
251 running: Some((receiver, process)),
252 sender: req_sender,
253 _p: PhantomData,
254 crash: None,
255 requests,
256 })
257 }
258
259 pub async fn shutdown(mut self) -> std::io::Result<()> {
261 if let Some((receiver, process)) = self.running.take() {
262 while !self.requests.lock().is_empty() {
263 crate::deadline(100.ms()).await;
264 }
265 let r = crate::wait(move || process.kill()).await;
266
267 match crate::with_deadline(crate::wait(move || receiver.join()), 1.secs()).await {
268 Ok(r) => {
269 if let Err(p) = r {
270 tracing::error!("worker receiver thread exited panicked, {}", crate::crate_util::panic_str(&p));
271 }
272 }
273 Err(_) => {
274 if r.is_ok() {
276 panic!("worker receiver thread did not exit after worker process did");
278 }
279 }
280 }
281 r
282 } else {
283 Ok(())
284 }
285 }
286
287 pub fn run(&mut self, input: I) -> impl Future<Output = Result<O, RunError>> + Send + 'static {
289 self.run_request(Request::Run(input))
290 }
291
292 fn run_request(&mut self, request: Request<I>) -> Pin<Box<dyn Future<Output = Result<O, RunError>> + Send + 'static>> {
293 if self.crash_error().is_some() {
294 return Box::pin(std::future::ready(Err(RunError::Disconnected)));
295 }
296
297 let id = RequestId::new_unique();
298 let (sx, rx) = flume::bounded(1);
299
300 let requests = self.requests.clone();
301 requests.lock().insert(id, sx);
302 let sender = self.sender.clone();
303 let send_r = crate::wait(move || sender.send((id, request)));
304
305 Box::pin(async move {
306 if let Err(e) = send_r.await {
307 requests.lock().remove(&id);
308 return Err(RunError::Other(Arc::new(e)));
309 }
310
311 match rx.recv_async().await {
312 Ok(r) => Ok(r),
313 Err(e) => match e {
314 flume::RecvError::Disconnected => {
315 requests.lock().remove(&id);
316 Err(RunError::Disconnected)
317 }
318 },
319 }
320 })
321 }
322
323 pub fn crash_error(&mut self) -> Option<&WorkerCrashError> {
327 if let Some((t, _)) = &self.running
328 && t.is_finished()
329 {
330 let (t, p) = self.running.take().unwrap();
331
332 if let Err(e) = t.join() {
333 tracing::error!("panic in worker receiver thread, {}", crate::crate_util::panic_str(&e));
334 }
335
336 if let Err(e) = p.kill() {
337 tracing::error!("error killing worker process after receiver exit, {e}");
338 }
339
340 match p.into_output() {
341 Ok(o) => {
342 self.crash = Some(WorkerCrashError {
343 status: o.status,
344 stdout: String::from_utf8_lossy(&o.stdout[..]).as_ref().to_txt(),
345 stderr: String::from_utf8_lossy(&o.stderr[..]).as_ref().to_txt(),
346 });
347 }
348 Err(e) => tracing::error!("error reading crashed worker output, {e}"),
349 }
350 }
351
352 self.crash.as_ref()
353 }
354}
355impl<I: IpcValue, O: IpcValue> Drop for Worker<I, O> {
356 fn drop(&mut self) {
357 if let Some((receiver, process)) = self.running.take() {
358 if !receiver.is_finished() {
359 tracing::error!("dropped worker without shutdown");
360 }
361 if let Err(e) = process.kill() {
362 tracing::error!("failed to kill worker process on drop, {e}");
363 }
364 }
365 }
366}
367
368pub fn run_worker<I, O, F>(worker_name: impl Into<Txt>, handler: impl Fn(RequestArgs<I>) -> F + Send + Sync + 'static)
373where
374 I: IpcValue,
375 O: IpcValue,
376 F: Future<Output = O> + Send + Sync + 'static,
377{
378 let name = worker_name.into();
379 zng_env::init_process_name(zng_txt::formatx!("worker-process ({name}, {})", std::process::id()));
380 if let Some(server_name) = run_worker_server(&name) {
381 let app_init_sender = IpcSender::<WorkerInit<I, O>>::connect(server_name)
382 .unwrap_or_else(|e| panic!("failed to connect to '{name}' init channel, {e}"));
383
384 let (req_sender, req_recv) = ipc_channel::ipc::channel().unwrap();
385 let (chan_sender, chan_recv) = ipc_channel::ipc::channel().unwrap();
386
387 app_init_sender.send((req_sender, chan_sender)).unwrap();
388 let rsp_sender = chan_recv.recv().unwrap();
389 let handler = Arc::new(handler);
390
391 loop {
392 match req_recv.recv() {
393 Ok((id, input)) => match input {
394 Request::Run(r) => crate::spawn(async_clmv!(handler, rsp_sender, {
395 let output = handler(RequestArgs { request: r }).await;
396 let _ = rsp_sender.send((id, Response::Out(output)));
397 })),
398 },
399 Err(e) => match e {
400 ipc_channel::ipc::IpcError::Bincode(e) => {
401 eprintln!("worker '{name}' request error, {e}")
402 }
403 ipc_channel::ipc::IpcError::Io(e) => panic!("worker '{name}' request io error, {e}"),
404 ipc_channel::ipc::IpcError::Disconnected => break,
405 },
406 }
407 }
408
409 zng_env::exit(0);
410 }
411}
412fn run_worker_server(worker_name: &str) -> Option<String> {
413 if let Ok(w_name) = std::env::var(WORKER_NAME)
414 && let Ok(version) = std::env::var(WORKER_VERSION)
415 && let Ok(server_name) = std::env::var(WORKER_SERVER)
416 {
417 if w_name != worker_name {
418 return None;
419 }
420 if version != VERSION {
421 eprintln!("worker '{worker_name}' API version is not equal, app-process: {version}, worker-process: {VERSION}");
422 zng_env::exit(i32::from_le_bytes(*b"vapi"));
423 }
424
425 Some(server_name)
426 } else {
427 None
428 }
429}
430
431#[non_exhaustive]
433pub struct RequestArgs<I: IpcValue> {
434 pub request: I,
436}
437
/// Error returned by a worker run request.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum RunError {
    /// The worker process is not connected: it crashed, or it disconnected before responding.
    Disconnected,
    /// Another error, such as a failure to send the request.
    Other(Arc<dyn std::error::Error + Send + Sync>),
}

impl fmt::Display for RunError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Disconnected => f.write_str("worker process disconnected"),
            Self::Other(source) => write!(f, "run error, {source}"),
        }
    }
}

impl std::error::Error for RunError {}
458
459#[derive(Debug, Clone)]
461#[non_exhaustive]
462pub struct WorkerCrashError {
463 pub status: std::process::ExitStatus,
465 pub stdout: Txt,
467 pub stderr: Txt,
469}
470impl fmt::Display for WorkerCrashError {
471 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
472 write!(f, "{:?}\nSTDOUT:\n{}\nSTDERR:\n{}", self.status, &self.stdout, &self.stderr)
473 }
474}
475impl std::error::Error for WorkerCrashError {}
476
477#[derive(serde::Serialize, serde::Deserialize)]
478enum Request<I> {
479 Run(I),
480}
481
482#[derive(serde::Serialize, serde::Deserialize)]
483enum Response<O> {
484 Out(O),
485}
486
487type WorkerInit<I, O> = (IpcSender<(RequestId, Request<I>)>, IpcSender<IpcSender<(RequestId, Response<O>)>>);
496
// Identifies one in-flight request: created with `new_unique` in `Worker::run_request`, sent along
// with the `Request`, and echoed back by the worker with the matching `Response`.
zng_unique_id::unique_id_64! {
    #[derive(serde::Serialize, serde::Deserialize)]
    struct RequestId;
}