1use std::collections::HashMap;
10use std::io::{self, ErrorKind};
11use std::path::Path;
12use std::process::Stdio;
13use std::sync::Arc;
14use std::sync::Mutex as StdMutex;
15use std::sync::atomic::AtomicBool;
16
17use anyhow::{Context, Result};
18use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader};
19use tokio::process::Command;
20use tokio::sync::{broadcast, mpsc, oneshot};
21use tokio::task::JoinHandle;
22
23use crate::process::{ChildTerminator, ProcessHandle, SpawnedProcess};
24use crate::process_group;
25
26#[cfg(target_os = "linux")]
27use libc;
28
/// Platform-specific identity needed to forcibly stop a pipe-spawned child.
///
/// Unix targets remember the process group id (the terminator kills the whole
/// group); Windows targets remember the child's pid.
struct PipeChildTerminator {
    /// Process group to signal when the child must be killed.
    #[cfg(unix)]
    process_group_id: u32,
    /// Pid of the child to kill.
    #[cfg(windows)]
    pid: u32,
}
36
37impl ChildTerminator for PipeChildTerminator {
38 fn kill(&mut self) -> io::Result<()> {
39 #[cfg(unix)]
40 {
41 process_group::kill_process_group(self.process_group_id)
42 }
43
44 #[cfg(windows)]
45 {
46 process_group::kill_process(self.pid)
47 }
48
49 #[cfg(not(any(unix, windows)))]
50 {
51 Ok(())
52 }
53 }
54}
55
56async fn read_output_stream<R>(mut reader: R, output_tx: broadcast::Sender<Vec<u8>>)
58where
59 R: AsyncRead + Unpin,
60{
61 let mut buf = vec![0u8; 8_192];
62 loop {
63 match reader.read(&mut buf).await {
64 Ok(0) => break,
65 Ok(n) => {
66 let _ = output_tx.send(buf[..n].to_vec());
67 }
68 Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
69 Err(_) => break,
70 }
71 }
72}
73
/// How the child's stdin is connected when it is spawned.
///
/// Defaults to [`PipeStdinMode::Piped`], matching `PipeSpawnOptions::new`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PipeStdinMode {
    /// Create a pipe; bytes sent through the process handle's writer channel
    /// are forwarded to the child's stdin.
    #[default]
    Piped,
    /// Connect stdin to the null device, so reads in the child see EOF.
    Null,
}
82
83#[derive(Clone)]
85pub struct PipeSpawnOptions {
86 pub program: String,
88 pub args: Vec<String>,
90 pub cwd: std::path::PathBuf,
92 pub env: Option<HashMap<String, String>>,
94 pub arg0: Option<String>,
96 pub stdin_mode: PipeStdinMode,
98}
99
100impl PipeSpawnOptions {
101 pub fn new(program: impl Into<String>, cwd: impl Into<std::path::PathBuf>) -> Self {
103 Self {
104 program: program.into(),
105 args: Vec::new(),
106 cwd: cwd.into(),
107 env: None,
108 arg0: None,
109 stdin_mode: PipeStdinMode::Piped,
110 }
111 }
112
113 pub fn args(mut self, args: impl IntoIterator<Item = impl Into<String>>) -> Self {
115 self.args = args.into_iter().map(Into::into).collect();
116 self
117 }
118
119 pub fn env(mut self, env: HashMap<String, String>) -> Self {
121 self.env = Some(env);
122 self
123 }
124
125 pub fn arg0(mut self, arg0: impl Into<String>) -> Self {
127 self.arg0 = Some(arg0.into());
128 self
129 }
130
131 pub fn stdin_mode(mut self, mode: PipeStdinMode) -> Self {
133 self.stdin_mode = mode;
134 self
135 }
136}
137
138async fn spawn_process_internal(opts: PipeSpawnOptions) -> Result<SpawnedProcess> {
140 if opts.program.is_empty() {
141 anyhow::bail!("missing program for pipe spawn");
142 }
143
144 let mut command = Command::new(&opts.program);
145
146 #[cfg(unix)]
147 if let Some(ref arg0) = opts.arg0 {
148 command.arg0(arg0);
149 }
150
151 #[cfg(target_os = "linux")]
152 let parent_pid = unsafe { libc::getpid() };
153
154 #[cfg(unix)]
155 unsafe {
156 command.pre_exec(move || {
157 process_group::detach_from_tty()?;
158 #[cfg(target_os = "linux")]
159 process_group::set_parent_death_signal(parent_pid)?;
160 Ok(())
161 });
162 }
163
164 #[cfg(not(unix))]
165 let _ = &opts.arg0;
166
167 command.current_dir(&opts.cwd);
168
169 if let Some(ref env) = opts.env {
171 command.env_clear();
172 for (key, value) in env {
173 command.env(key, value);
174 }
175 }
176
177 for arg in &opts.args {
178 command.arg(arg);
179 }
180
181 match opts.stdin_mode {
182 PipeStdinMode::Piped => {
183 command.stdin(Stdio::piped());
184 }
185 PipeStdinMode::Null => {
186 command.stdin(Stdio::null());
187 }
188 }
189 command.stdout(Stdio::piped());
190 command.stderr(Stdio::piped());
191
192 let mut child = command.spawn().context("failed to spawn pipe process")?;
193 let pid = child
194 .id()
195 .ok_or_else(|| io::Error::other("missing child pid"))?;
196
197 #[cfg(unix)]
198 let process_group_id = pid;
199
200 let stdin = child.stdin.take();
201 let stdout = child.stdout.take();
202 let stderr = child.stderr.take();
203
204 let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
205 let (output_tx, _) = broadcast::channel::<Vec<u8>>(256);
206 let initial_output_rx = output_tx.subscribe();
207
208 let writer_handle = if let Some(stdin) = stdin {
210 let writer = Arc::new(tokio::sync::Mutex::new(stdin));
211 tokio::spawn(async move {
212 while let Some(bytes) = writer_rx.recv().await {
213 let mut guard = writer.lock().await;
214 let _ = guard.write_all(&bytes).await;
215 let _ = guard.flush().await;
216 }
217 })
218 } else {
219 drop(writer_rx);
220 tokio::spawn(async {})
221 };
222
223 let stdout_handle = stdout.map(|stdout| {
225 let output_tx = output_tx.clone();
226 tokio::spawn(async move {
227 read_output_stream(BufReader::new(stdout), output_tx).await;
228 })
229 });
230
231 let stderr_handle = stderr.map(|stderr| {
232 let output_tx = output_tx.clone();
233 tokio::spawn(async move {
234 read_output_stream(BufReader::new(stderr), output_tx).await;
235 })
236 });
237
238 let mut reader_abort_handles = Vec::new();
239 if let Some(ref handle) = stdout_handle {
240 reader_abort_handles.push(handle.abort_handle());
241 }
242 if let Some(ref handle) = stderr_handle {
243 reader_abort_handles.push(handle.abort_handle());
244 }
245
246 let reader_handle = tokio::spawn(async move {
247 if let Some(handle) = stdout_handle {
248 let _ = handle.await;
249 }
250 if let Some(handle) = stderr_handle {
251 let _ = handle.await;
252 }
253 });
254
255 let (exit_tx, exit_rx) = oneshot::channel::<i32>();
257 let exit_status = Arc::new(AtomicBool::new(false));
258 let wait_exit_status = Arc::clone(&exit_status);
259 let exit_code = Arc::new(StdMutex::new(None));
260 let wait_exit_code = Arc::clone(&exit_code);
261
262 let wait_handle: JoinHandle<()> = tokio::spawn(async move {
263 let code = match child.wait().await {
264 Ok(status) => status.code().unwrap_or(-1),
265 Err(_) => -1,
266 };
267 wait_exit_status.store(true, std::sync::atomic::Ordering::SeqCst);
268 if let Ok(mut guard) = wait_exit_code.lock() {
269 *guard = Some(code);
270 }
271 let _ = exit_tx.send(code);
272 });
273
274 let (handle, output_rx) = ProcessHandle::new(
275 writer_tx,
276 output_tx,
277 initial_output_rx,
278 Box::new(PipeChildTerminator {
279 #[cfg(windows)]
280 pid,
281 #[cfg(unix)]
282 process_group_id,
283 }),
284 reader_handle,
285 reader_abort_handles,
286 writer_handle,
287 wait_handle,
288 exit_status,
289 exit_code,
290 None,
291 );
292
293 Ok(SpawnedProcess {
294 session: handle,
295 output_rx,
296 exit_rx,
297 })
298}
299
300pub async fn spawn_process(
314 program: &str,
315 args: &[String],
316 cwd: &Path,
317 env: &HashMap<String, String>,
318 arg0: &Option<String>,
319) -> Result<SpawnedProcess> {
320 let opts = PipeSpawnOptions {
321 program: program.to_string(),
322 args: args.to_vec(),
323 cwd: cwd.to_path_buf(),
324 env: Some(env.clone()),
325 arg0: arg0.clone(),
326 stdin_mode: PipeStdinMode::Piped,
327 };
328 spawn_process_internal(opts).await
329}
330
331pub async fn spawn_process_no_stdin(
335 program: &str,
336 args: &[String],
337 cwd: &Path,
338 env: &HashMap<String, String>,
339 arg0: &Option<String>,
340) -> Result<SpawnedProcess> {
341 let opts = PipeSpawnOptions {
342 program: program.to_string(),
343 args: args.to_vec(),
344 cwd: cwd.to_path_buf(),
345 env: Some(env.clone()),
346 arg0: arg0.clone(),
347 stdin_mode: PipeStdinMode::Null,
348 };
349 spawn_process_internal(opts).await
350}
351
352pub async fn spawn_process_with_options(opts: PipeSpawnOptions) -> Result<SpawnedProcess> {
354 spawn_process_internal(opts).await
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a program plus leading arguments that print the remaining
    /// arguments and exit with status 0.
    fn find_echo_command() -> Option<(String, Vec<String>)> {
        #[cfg(windows)]
        {
            Some((
                "cmd.exe".to_string(),
                vec!["/C".to_string(), "echo".to_string()],
            ))
        }
        #[cfg(not(windows))]
        {
            Some(("echo".to_string(), vec![]))
        }
    }

    /// The current environment; the spawn helpers replace the child's
    /// environment wholesale, so `PATH` must be passed through explicitly.
    fn inherited_env() -> HashMap<String, String> {
        std::env::vars().collect()
    }

    #[tokio::test]
    async fn test_spawn_process_echo() -> anyhow::Result<()> {
        let Some((program, mut base_args)) = find_echo_command() else {
            return Ok(());
        };
        base_args.push("hello".to_string());

        let env = inherited_env();
        let spawned = spawn_process(&program, &base_args, Path::new("."), &env, &None).await?;

        let exit_code = spawned.exit_rx.await.unwrap_or(-1);
        assert_eq!(exit_code, 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_spawn_process_no_stdin_echo() -> anyhow::Result<()> {
        let Some((program, mut base_args)) = find_echo_command() else {
            return Ok(());
        };
        base_args.push("hello".to_string());

        let env = inherited_env();
        let spawned =
            spawn_process_no_stdin(&program, &base_args, Path::new("."), &env, &None).await?;

        let exit_code = spawned.exit_rx.await.unwrap_or(-1);
        assert_eq!(exit_code, 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_spawn_rejects_empty_program() {
        let result = spawn_process_with_options(PipeSpawnOptions::new("", ".")).await;
        let err = match result {
            Ok(_) => panic!("an empty program must be rejected"),
            Err(err) => err,
        };
        assert!(err.to_string().contains("missing program"));
    }

    // Pure builder logic: no runtime needed, so a plain #[test] suffices.
    #[test]
    fn test_spawn_options_builder() {
        let opts = PipeSpawnOptions::new("echo", ".")
            .args(["hello", "world"])
            .stdin_mode(PipeStdinMode::Null);

        assert_eq!(opts.program, "echo");
        assert_eq!(opts.args, vec!["hello", "world"]);
        assert!(matches!(opts.stdin_mode, PipeStdinMode::Null));
    }

    #[test]
    fn test_spawn_options_defaults_and_optional_fields() {
        let defaults = PipeSpawnOptions::new("prog", "/tmp");
        assert!(defaults.args.is_empty());
        assert!(defaults.env.is_none());
        assert!(defaults.arg0.is_none());
        assert!(matches!(defaults.stdin_mode, PipeStdinMode::Piped));
        assert_eq!(defaults.cwd, std::path::PathBuf::from("/tmp"));

        let configured = defaults
            .env(HashMap::from([("KEY".to_string(), "VALUE".to_string())]))
            .arg0("custom-argv0");
        assert_eq!(
            configured
                .env
                .as_ref()
                .and_then(|env| env.get("KEY"))
                .map(String::as_str),
            Some("VALUE")
        );
        assert_eq!(configured.arg0.as_deref(), Some("custom-argv0"));
    }
}