1#![cfg(ipc)]
2
3use core::fmt;
78use std::{marker::PhantomData, path::PathBuf, pin::Pin, sync::Arc};
79
80use parking_lot::Mutex;
81use zng_clone_move::{async_clmv, clmv};
82use zng_txt::{ToTxt, Txt};
83use zng_unique_id::IdMap;
84use zng_unit::TimeUnits as _;
85
86#[doc(no_inline)]
87pub use ipc_channel::ipc::{IpcBytesReceiver, IpcBytesSender, IpcReceiver, IpcSender, bytes_channel};
88
89#[diagnostic::on_unimplemented(note = "`IpcValue` is implemented for all `T: Debug + Serialize + Deserialize + Send + 'static`")]
102pub trait IpcValue: fmt::Debug + serde::Serialize + for<'d> serde::de::Deserialize<'d> + Send + 'static {}
103
104impl<T: fmt::Debug + serde::Serialize + for<'d> serde::de::Deserialize<'d> + Send + 'static> IpcValue for T {}
105
/// Name of the environment variable that carries the app-process [`VERSION`] to the worker process.
const WORKER_VERSION: &str = "ZNG_TASK_IPC_WORKER_VERSION";
/// Name of the environment variable that carries the one-shot IPC server name the worker must connect to.
const WORKER_SERVER: &str = "ZNG_TASK_IPC_WORKER_SERVER";
/// Name of the environment variable that carries the name of the worker the spawned process must run.
const WORKER_NAME: &str = "ZNG_TASK_IPC_WORKER_NAME";
109
/// Version of this crate, read from `CARGO_PKG_VERSION` at compile time.
///
/// The app-process sends this value to the worker process, which exits if its own value differs.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
113
114pub struct Worker<I: IpcValue, O: IpcValue> {
116 running: Option<(std::thread::JoinHandle<()>, duct::Handle)>,
117
118 sender: ipc_channel::ipc::IpcSender<(RequestId, Request<I>)>,
119 requests: Arc<Mutex<IdMap<RequestId, flume::Sender<O>>>>,
120
121 _p: PhantomData<fn(I) -> O>,
122
123 crash: Option<WorkerCrashError>,
124}
125impl<I: IpcValue, O: IpcValue> Worker<I, O> {
126 pub async fn start(worker_name: impl Into<Txt>) -> std::io::Result<Self> {
131 Self::start_impl(worker_name.into(), duct::cmd!(dunce::canonicalize(std::env::current_exe()?)?)).await
132 }
133
134 pub async fn start_with(worker_name: impl Into<Txt>, env_vars: &[(&str, &str)], args: &[&str]) -> std::io::Result<Self> {
136 let mut worker = duct::cmd(dunce::canonicalize(std::env::current_exe()?)?, args);
137 for (name, value) in env_vars {
138 worker = worker.env(name, value);
139 }
140 Self::start_impl(worker_name.into(), worker).await
141 }
142
143 pub async fn start_other(
145 worker_name: impl Into<Txt>,
146 worker_exe: impl Into<PathBuf>,
147 env_vars: &[(&str, &str)],
148 args: &[&str],
149 ) -> std::io::Result<Self> {
150 let mut worker = duct::cmd(worker_exe.into(), args);
151 for (name, value) in env_vars {
152 worker = worker.env(name, value);
153 }
154 Self::start_impl(worker_name.into(), worker).await
155 }
156
157 pub async fn start_duct(worker_name: impl Into<Txt>, worker: duct::Expression) -> std::io::Result<Self> {
164 Self::start_impl(worker_name.into(), worker).await
165 }
166
167 async fn start_impl(worker_name: Txt, worker: duct::Expression) -> std::io::Result<Self> {
168 let (server, name) = ipc_channel::ipc::IpcOneShotServer::<WorkerInit<I, O>>::new()?;
169
170 let worker = worker
171 .env(WORKER_VERSION, crate::ipc::VERSION)
172 .env(WORKER_SERVER, name)
173 .env(WORKER_NAME, worker_name)
174 .env("RUST_BACKTRACE", "full")
175 .stdin_null()
176 .stdout_capture()
177 .stderr_capture()
178 .unchecked();
179
180 let process = crate::wait(move || worker.start()).await?;
181
182 let r = crate::with_deadline(crate::wait(move || server.accept()), 10.secs()).await;
183
184 let (_, (req_sender, chan_sender)) = match r {
185 Ok(r) => match r {
186 Ok(r) => r,
187 Err(e) => return Err(std::io::Error::new(std::io::ErrorKind::ConnectionRefused, e)),
188 },
189 Err(_) => match process.kill() {
190 Ok(()) => {
191 let output = process.wait().unwrap();
192 let stdout = String::from_utf8_lossy(&output.stdout);
193 let stderr = String::from_utf8_lossy(&output.stderr);
194 let code = output.status.code().unwrap_or(0);
195 return Err(std::io::Error::new(
196 std::io::ErrorKind::TimedOut,
197 format!(
198 "worker process did not connect in 10 seconds\nworker exit code: {code}\n--worker stdout--\n{stdout}\n--worker stderr--\n{stderr}"
199 ),
200 ));
201 }
202 Err(e) => {
203 return Err(std::io::Error::new(
204 std::io::ErrorKind::TimedOut,
205 format!("worker process did not connect in 10s\ncannot be kill worker process, {e}"),
206 ));
207 }
208 },
209 };
210
211 let (rsp_sender, rsp_recv) = ipc_channel::ipc::channel()?;
212 crate::wait(move || chan_sender.send(rsp_sender)).await.unwrap();
213
214 let requests = Arc::new(Mutex::new(IdMap::<RequestId, flume::Sender<O>>::new()));
215 let receiver = std::thread::spawn(clmv!(requests, || {
216 loop {
217 match rsp_recv.recv() {
218 Ok((id, r)) => match requests.lock().remove(&id) {
219 Some(s) => match r {
220 Response::Out(r) => {
221 let _ = s.send(r);
222 }
223 },
224 None => tracing::error!("worker responded to unknown request #{}", id.sequential()),
225 },
226 Err(e) => match e {
227 ipc_channel::ipc::IpcError::Disconnected => {
228 requests.lock().clear();
229 break;
230 }
231 ipc_channel::ipc::IpcError::Bincode(e) => {
232 tracing::error!("worker response error, will shutdown, {e}");
233 break;
234 }
235 ipc_channel::ipc::IpcError::Io(e) => {
236 tracing::error!("worker response io error, will shutdown, {e}");
237 break;
238 }
239 },
240 }
241 }
242 }));
243
244 Ok(Self {
245 running: Some((receiver, process)),
246 sender: req_sender,
247 _p: PhantomData,
248 crash: None,
249 requests,
250 })
251 }
252
253 pub async fn shutdown(mut self) -> std::io::Result<()> {
255 if let Some((receiver, process)) = self.running.take() {
256 while !self.requests.lock().is_empty() {
257 crate::deadline(100.ms()).await;
258 }
259 let r = crate::wait(move || process.kill()).await;
260
261 match crate::with_deadline(crate::wait(move || receiver.join()), 1.secs()).await {
262 Ok(r) => {
263 if let Err(p) = r {
264 tracing::error!("worker receiver thread exited panicked, {}", crate::crate_util::panic_str(&p));
265 }
266 }
267 Err(_) => {
268 if r.is_ok() {
270 panic!("worker receiver thread did not exit after worker process did");
272 }
273 }
274 }
275 r
276 } else {
277 Ok(())
278 }
279 }
280
281 pub fn run(&mut self, input: I) -> impl Future<Output = Result<O, RunError>> + Send + 'static {
283 self.run_request(Request::Run(input))
284 }
285
286 fn run_request(&mut self, request: Request<I>) -> Pin<Box<dyn Future<Output = Result<O, RunError>> + Send + 'static>> {
287 if self.crash_error().is_some() {
288 return Box::pin(std::future::ready(Err(RunError::Disconnected)));
289 }
290
291 let id = RequestId::new_unique();
292 let (sx, rx) = flume::bounded(1);
293
294 let requests = self.requests.clone();
295 requests.lock().insert(id, sx);
296 let sender = self.sender.clone();
297 let send_r = crate::wait(move || sender.send((id, request)));
298
299 Box::pin(async move {
300 if let Err(e) = send_r.await {
301 requests.lock().remove(&id);
302 return Err(RunError::Other(Arc::new(e)));
303 }
304
305 match rx.recv_async().await {
306 Ok(r) => Ok(r),
307 Err(e) => match e {
308 flume::RecvError::Disconnected => {
309 requests.lock().remove(&id);
310 Err(RunError::Disconnected)
311 }
312 },
313 }
314 })
315 }
316
317 pub fn crash_error(&mut self) -> Option<&WorkerCrashError> {
321 if let Some((t, _)) = &self.running {
322 if t.is_finished() {
323 let (t, p) = self.running.take().unwrap();
324
325 if let Err(e) = t.join() {
326 tracing::error!("panic in worker receiver thread, {}", crate::crate_util::panic_str(&e));
327 }
328
329 if let Err(e) = p.kill() {
330 tracing::error!("error killing worker process after receiver exit, {e}");
331 }
332
333 match p.into_output() {
334 Ok(o) => {
335 self.crash = Some(WorkerCrashError {
336 status: o.status,
337 stdout: String::from_utf8_lossy(&o.stdout[..]).as_ref().to_txt(),
338 stderr: String::from_utf8_lossy(&o.stderr[..]).as_ref().to_txt(),
339 });
340 }
341 Err(e) => tracing::error!("error reading crashed worker output, {e}"),
342 }
343 }
344 }
345
346 self.crash.as_ref()
347 }
348}
349impl<I: IpcValue, O: IpcValue> Drop for Worker<I, O> {
350 fn drop(&mut self) {
351 if let Some((receiver, process)) = self.running.take() {
352 if !receiver.is_finished() {
353 tracing::error!("dropped worker without shutdown");
354 }
355 if let Err(e) = process.kill() {
356 tracing::error!("failed to kill worker process on drop, {e}");
357 }
358 }
359 }
360}
361
362pub fn run_worker<I, O, F>(worker_name: impl Into<Txt>, handler: impl Fn(RequestArgs<I>) -> F + Send + Sync + 'static)
367where
368 I: IpcValue,
369 O: IpcValue,
370 F: Future<Output = O> + Send + Sync + 'static,
371{
372 let name = worker_name.into();
373 if let Some(server_name) = run_worker_server(&name) {
374 let app_init_sender = IpcSender::<WorkerInit<I, O>>::connect(server_name)
375 .unwrap_or_else(|e| panic!("failed to connect to '{name}' init channel, {e}"));
376
377 let (req_sender, req_recv) = ipc_channel::ipc::channel().unwrap();
378 let (chan_sender, chan_recv) = ipc_channel::ipc::channel().unwrap();
379
380 app_init_sender.send((req_sender, chan_sender)).unwrap();
381 let rsp_sender = chan_recv.recv().unwrap();
382 let handler = Arc::new(handler);
383
384 loop {
385 match req_recv.recv() {
386 Ok((id, input)) => match input {
387 Request::Run(r) => crate::spawn(async_clmv!(handler, rsp_sender, {
388 let output = handler(RequestArgs { request: r }).await;
389 let _ = rsp_sender.send((id, Response::Out(output)));
390 })),
391 },
392 Err(e) => match e {
393 ipc_channel::ipc::IpcError::Bincode(e) => {
394 eprintln!("worker '{name}' request error, {e}")
395 }
396 ipc_channel::ipc::IpcError::Io(e) => panic!("worker '{name}' request io error, {e}"),
397 ipc_channel::ipc::IpcError::Disconnected => break,
398 },
399 }
400 }
401
402 zng_env::exit(0);
403 }
404}
405fn run_worker_server(worker_name: &str) -> Option<String> {
406 if let (Ok(w_name), Ok(version), Ok(server_name)) = (
407 std::env::var(WORKER_NAME),
408 std::env::var(WORKER_VERSION),
409 std::env::var(WORKER_SERVER),
410 ) {
411 if w_name != worker_name {
412 return None;
413 }
414 if version != VERSION {
415 eprintln!("worker '{worker_name}' API version is not equal, app-process: {version}, worker-process: {VERSION}");
416 zng_env::exit(i32::from_le_bytes(*b"vapi"));
417 }
418
419 Some(server_name)
420 } else {
421 None
422 }
423}
424
/// Arguments for the request handler given to [`run_worker`].
#[non_exhaustive]
pub struct RequestArgs<I: IpcValue> {
    /// The input value sent by the app-process.
    pub request: I,
}
431
/// Error of a worker request.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum RunError {
    /// The connection with the worker process was lost before a response was received.
    Disconnected,
    /// Any other error, such as a failure to send the request.
    Other(Arc<dyn std::error::Error + Send + Sync>),
}
impl fmt::Display for RunError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Disconnected => f.write_str("worker process disconnected"),
            Self::Other(inner) => write!(f, "run error, {inner}"),
        }
    }
}
impl std::error::Error for RunError {}
452
453#[derive(Debug, Clone)]
455#[non_exhaustive]
456pub struct WorkerCrashError {
457 pub status: std::process::ExitStatus,
459 pub stdout: Txt,
461 pub stderr: Txt,
463}
464impl fmt::Display for WorkerCrashError {
465 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
466 write!(f, "{:?}\nSTDOUT:\n{}\nSTDERR:\n{}", self.status, &self.stdout, &self.stderr)
467 }
468}
469impl std::error::Error for WorkerCrashError {}
470
/// Message sent from the app-process to the worker process.
#[derive(serde::Serialize, serde::Deserialize)]
enum Request<I> {
    /// Run the worker handler with the given input.
    Run(I),
}
475
/// Message sent from the worker process back to the app-process.
#[derive(serde::Serialize, serde::Deserialize)]
enum Response<O> {
    /// Output of the worker handler for a request.
    Out(O),
}
480
/// Handshake message the worker process sends to the app-process one-shot server.
///
/// The first item is the sender the app-process uses to send requests, the second is the sender the
/// app-process uses to give the worker the response sender.
type WorkerInit<I, O> = (IpcSender<(RequestId, Request<I>)>, IpcSender<IpcSender<(RequestId, Response<O>)>>);
490
// Unique identifier of a request, sent with the request and echoed in the response so the
// app-process can match a response to its pending request.
zng_unique_id::unique_id_64! {
    #[derive(serde::Serialize, serde::Deserialize)]
    struct RequestId;
}