1use std::collections::HashMap;
10use std::io::{self, ErrorKind};
11use std::path::Path;
12use std::process::Stdio;
13use std::sync::Arc;
14use std::sync::Mutex as StdMutex;
15use std::sync::atomic::AtomicBool;
16
17use anyhow::{Context, Result};
18use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader};
19use tokio::process::Command;
20use tokio::sync::{broadcast, mpsc, oneshot};
21use tokio::task::JoinHandle;
22
23use crate::process::{ChildTerminator, ProcessHandle, SpawnedProcess};
24use crate::process_group;
25
26struct PipeChildTerminator {
28 #[cfg(windows)]
29 pid: u32,
30 #[cfg(unix)]
31 process_group_id: u32,
32}
33
34impl ChildTerminator for PipeChildTerminator {
35 fn kill(&mut self) -> io::Result<()> {
36 #[cfg(unix)]
37 {
38 process_group::kill_process_group(self.process_group_id)
39 }
40
41 #[cfg(windows)]
42 {
43 process_group::kill_process(self.pid)
44 }
45
46 #[cfg(not(any(unix, windows)))]
47 {
48 Ok(())
49 }
50 }
51}
52
53async fn read_output_stream<R>(mut reader: R, output_tx: broadcast::Sender<Vec<u8>>)
55where
56 R: AsyncRead + Unpin,
57{
58 let mut buf = vec![0u8; 8_192];
59 loop {
60 match reader.read(&mut buf).await {
61 Ok(0) => break,
62 Ok(n) => {
63 let _ = output_tx.send(buf[..n].to_vec());
64 }
65 Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
66 Err(_) => break,
67 }
68 }
69}
70
/// How the child's standard input is connected when spawning over pipes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PipeStdinMode {
    /// Stdin is a pipe; input written through the spawned process handle is
    /// forwarded to the child.
    Piped,
    /// Stdin is attached to the null device, so the child reads EOF immediately.
    Null,
}

/// Everything needed to spawn a child process whose stdio is wired to pipes.
///
/// Create one with [`PipeSpawnOptions::new`] and refine it with the chained setters.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PipeSpawnOptions {
    /// Program to execute. Spawning rejects an empty program.
    pub program: String,
    /// Arguments passed after the program name.
    pub args: Vec<String>,
    /// Working directory of the child.
    pub cwd: std::path::PathBuf,
    /// Replacement environment. When `Some`, the inherited environment is cleared
    /// and only these variables are set; when `None`, the child inherits the
    /// parent's environment.
    pub env: Option<HashMap<String, String>>,
    /// Override for `argv[0]`. It is honored on Unix only.
    pub arg0: Option<String>,
    /// How stdin is connected. Defaults to [`PipeStdinMode::Piped`].
    pub stdin_mode: PipeStdinMode,
}

impl PipeSpawnOptions {
    /// Creates options for running `program` in `cwd`.
    ///
    /// The defaults are: no arguments, an inherited environment, no `argv[0]`
    /// override, and piped stdin.
    #[must_use]
    pub fn new(program: impl Into<String>, cwd: impl Into<std::path::PathBuf>) -> Self {
        Self {
            program: program.into(),
            args: Vec::new(),
            cwd: cwd.into(),
            env: None,
            arg0: None,
            stdin_mode: PipeStdinMode::Piped,
        }
    }

    /// Replaces the argument list. Previously set arguments are discarded.
    #[must_use]
    pub fn args(mut self, args: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.args = args.into_iter().map(Into::into).collect();
        self
    }

    /// Sets an explicit environment that fully replaces the inherited one.
    #[must_use]
    pub fn env(mut self, env: HashMap<String, String>) -> Self {
        self.env = Some(env);
        self
    }

    /// Overrides `argv[0]` for the child. This has an effect on Unix only.
    #[must_use]
    pub fn arg0(mut self, arg0: impl Into<String>) -> Self {
        self.arg0 = Some(arg0.into());
        self
    }

    /// Chooses how the child's stdin is connected.
    #[must_use]
    pub fn stdin_mode(mut self, mode: PipeStdinMode) -> Self {
        self.stdin_mode = mode;
        self
    }
}
134
135async fn spawn_process_internal(opts: PipeSpawnOptions) -> Result<SpawnedProcess> {
137 if opts.program.is_empty() {
138 anyhow::bail!("missing program for pipe spawn");
139 }
140
141 let mut command = Command::new(&opts.program);
142
143 #[cfg(unix)]
144 if let Some(ref arg0) = opts.arg0 {
145 command.arg0(arg0);
146 }
147
148 #[cfg(unix)]
149 {
150 command.process_group(0);
151 }
152
153 #[cfg(not(unix))]
154 let _ = &opts.arg0;
155
156 command.current_dir(&opts.cwd);
157
158 if let Some(ref env) = opts.env {
160 command.env_clear();
161 for (key, value) in env {
162 command.env(key, value);
163 }
164 }
165
166 for arg in &opts.args {
167 command.arg(arg);
168 }
169
170 match opts.stdin_mode {
171 PipeStdinMode::Piped => {
172 command.stdin(Stdio::piped());
173 }
174 PipeStdinMode::Null => {
175 command.stdin(Stdio::null());
176 }
177 }
178 command.stdout(Stdio::piped());
179 command.stderr(Stdio::piped());
180
181 let mut child = command.spawn().context("failed to spawn pipe process")?;
182 let pid = child
183 .id()
184 .ok_or_else(|| io::Error::other("missing child pid"))?;
185
186 #[cfg(unix)]
187 let process_group_id = pid;
188
189 let stdin = child.stdin.take();
190 let stdout = child.stdout.take();
191 let stderr = child.stderr.take();
192
193 let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
194 let (output_tx, _) = broadcast::channel::<Vec<u8>>(256);
195 let initial_output_rx = output_tx.subscribe();
196
197 let writer_handle = if let Some(stdin) = stdin {
199 let writer = Arc::new(tokio::sync::Mutex::new(stdin));
200 tokio::spawn(async move {
201 while let Some(bytes) = writer_rx.recv().await {
202 let mut guard = writer.lock().await;
203 let _ = guard.write_all(&bytes).await;
204 let _ = guard.flush().await;
205 }
206 })
207 } else {
208 drop(writer_rx);
209 tokio::spawn(async {})
210 };
211
212 let stdout_handle = stdout.map(|stdout| {
214 let output_tx = output_tx.clone();
215 tokio::spawn(async move {
216 read_output_stream(BufReader::new(stdout), output_tx).await;
217 })
218 });
219
220 let stderr_handle = stderr.map(|stderr| {
221 let output_tx = output_tx.clone();
222 tokio::spawn(async move {
223 read_output_stream(BufReader::new(stderr), output_tx).await;
224 })
225 });
226
227 let mut reader_abort_handles = Vec::new();
228 if let Some(ref handle) = stdout_handle {
229 reader_abort_handles.push(handle.abort_handle());
230 }
231 if let Some(ref handle) = stderr_handle {
232 reader_abort_handles.push(handle.abort_handle());
233 }
234
235 let reader_handle = tokio::spawn(async move {
236 if let Some(handle) = stdout_handle {
237 let _ = handle.await;
238 }
239 if let Some(handle) = stderr_handle {
240 let _ = handle.await;
241 }
242 });
243
244 let (exit_tx, exit_rx) = oneshot::channel::<i32>();
246 let exit_status = Arc::new(AtomicBool::new(false));
247 let wait_exit_status = Arc::clone(&exit_status);
248 let exit_code = Arc::new(StdMutex::new(None));
249 let wait_exit_code = Arc::clone(&exit_code);
250
251 let wait_handle: JoinHandle<()> = tokio::spawn(async move {
252 let code = match child.wait().await {
253 Ok(status) => status.code().unwrap_or(-1),
254 Err(_) => -1,
255 };
256 wait_exit_status.store(true, std::sync::atomic::Ordering::SeqCst);
257 if let Ok(mut guard) = wait_exit_code.lock() {
258 *guard = Some(code);
259 }
260 let _ = exit_tx.send(code);
261 });
262
263 let (handle, output_rx) = ProcessHandle::new(
264 writer_tx,
265 output_tx,
266 initial_output_rx,
267 Box::new(PipeChildTerminator {
268 #[cfg(windows)]
269 pid,
270 #[cfg(unix)]
271 process_group_id,
272 }),
273 reader_handle,
274 reader_abort_handles,
275 writer_handle,
276 wait_handle,
277 exit_status,
278 exit_code,
279 None,
280 );
281
282 Ok(SpawnedProcess {
283 session: handle,
284 output_rx,
285 exit_rx,
286 })
287}
288
289pub async fn spawn_process(
303 program: &str,
304 args: &[String],
305 cwd: &Path,
306 env: &HashMap<String, String>,
307 arg0: &Option<String>,
308) -> Result<SpawnedProcess> {
309 let opts = PipeSpawnOptions {
310 program: program.to_string(),
311 args: args.to_vec(),
312 cwd: cwd.to_path_buf(),
313 env: Some(env.clone()),
314 arg0: arg0.clone(),
315 stdin_mode: PipeStdinMode::Piped,
316 };
317 spawn_process_internal(opts).await
318}
319
320pub async fn spawn_process_no_stdin(
324 program: &str,
325 args: &[String],
326 cwd: &Path,
327 env: &HashMap<String, String>,
328 arg0: &Option<String>,
329) -> Result<SpawnedProcess> {
330 let opts = PipeSpawnOptions {
331 program: program.to_string(),
332 args: args.to_vec(),
333 cwd: cwd.to_path_buf(),
334 env: Some(env.clone()),
335 arg0: arg0.clone(),
336 stdin_mode: PipeStdinMode::Null,
337 };
338 spawn_process_internal(opts).await
339}
340
341pub async fn spawn_process_with_options(opts: PipeSpawnOptions) -> Result<SpawnedProcess> {
343 spawn_process_internal(opts).await
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a program plus leading args that print the remaining args and exit 0.
    fn find_echo_command() -> Option<(String, Vec<String>)> {
        #[cfg(windows)]
        {
            Some((
                "cmd.exe".to_string(),
                vec!["/C".to_string(), "echo".to_string()],
            ))
        }
        #[cfg(not(windows))]
        {
            Some(("echo".to_string(), vec![]))
        }
    }

    /// Copies the current environment; the spawn helpers give the child exactly this map.
    fn inherited_env() -> HashMap<String, String> {
        std::env::vars().collect()
    }

    #[tokio::test]
    async fn test_spawn_process_echo() -> Result<()> {
        let Some((program, mut base_args)) = find_echo_command() else {
            return Ok(());
        };
        base_args.push("hello".to_string());

        let env = inherited_env();
        let spawned = spawn_process(&program, &base_args, Path::new("."), &env, &None).await?;

        let exit_code = spawned.exit_rx.await.unwrap_or(-1);
        assert_eq!(exit_code, 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_spawn_process_no_stdin_echo() -> Result<()> {
        let Some((program, mut base_args)) = find_echo_command() else {
            return Ok(());
        };
        base_args.push("hello".to_string());

        let env = inherited_env();
        let spawned =
            spawn_process_no_stdin(&program, &base_args, Path::new("."), &env, &None).await?;

        let exit_code = spawned.exit_rx.await.unwrap_or(-1);
        assert_eq!(exit_code, 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_spawn_rejects_empty_program() {
        match spawn_process_with_options(PipeSpawnOptions::new("", ".")).await {
            Ok(_) => panic!("an empty program must be rejected"),
            Err(err) => assert!(err.to_string().contains("missing program")),
        }
    }

    #[test]
    fn test_spawn_options_builder() {
        let env = HashMap::from([("KEY".to_string(), "VALUE".to_string())]);
        let opts = PipeSpawnOptions::new("echo", ".")
            .args(["hello", "world"])
            .env(env.clone())
            .arg0("my-echo")
            .stdin_mode(PipeStdinMode::Null);

        assert_eq!(opts.program, "echo");
        assert_eq!(opts.args, vec!["hello", "world"]);
        assert_eq!(opts.cwd, std::path::PathBuf::from("."));
        assert_eq!(opts.env, Some(env));
        assert_eq!(opts.arg0.as_deref(), Some("my-echo"));
        assert!(matches!(opts.stdin_mode, PipeStdinMode::Null));
    }

    #[test]
    fn test_spawn_options_defaults() {
        let opts = PipeSpawnOptions::new("echo", ".");

        assert!(opts.args.is_empty());
        assert_eq!(opts.env, None);
        assert_eq!(opts.arg0, None);
        assert!(matches!(opts.stdin_mode, PipeStdinMode::Piped));
    }
}