1use std::io::Write;
2use std::process::{Command, Stdio};
3use std::time::Duration;
4use thiserror::Error;
5use wait_timeout::ChildExt;
6
7#[derive(Debug, Error)]
8pub enum ShellError {
9 #[error("IO error: {0}")]
10 Io(#[from] std::io::Error),
11 #[error("Command `{0}` timed out after {1:?}")]
12 Timeout(String, Duration),
13 #[error("Command `{0}` failed with status {1}")]
14 CommandFailed(String, std::process::ExitStatus),
15 #[error("Command output was not valid UTF-8")]
16 InvalidUtf8(#[from] std::string::FromUtf8Error),
17}
18
19pub fn run_piped(
34 command_str: &str,
35 input: &str,
36 timeout: Option<Duration>,
37) -> Result<String, ShellError> {
38 let mut cmd = if cfg!(target_os = "windows") {
39 let mut c = Command::new("cmd");
40 c.arg("/C").arg(command_str);
41 c
42 } else {
43 let mut c = Command::new("sh");
44 c.arg("-c").arg(command_str);
45 c
46 };
47
48 cmd.stdin(Stdio::piped())
49 .stdout(Stdio::piped())
50 .stderr(Stdio::inherit());
51
52 let mut child = cmd.spawn()?;
53
54 if let Some(mut stdin) = child.stdin.take() {
55 stdin.write_all(input.as_bytes())?;
56 }
57
58 match timeout {
59 Some(duration) => match child.wait_timeout(duration)? {
60 Some(status) => {
61 if !status.success() {
62 return Err(ShellError::CommandFailed(command_str.to_string(), status));
63 }
64 }
65 None => {
66 child.kill()?;
67 return Err(ShellError::Timeout(command_str.to_string(), duration));
68 }
69 },
70 None => {
71 let status = child.wait()?;
72 if !status.success() {
73 return Err(ShellError::CommandFailed(command_str.to_string(), status));
74 }
75 }
76 }
77
78 let mut output = String::new();
79 if let Some(mut stdout) = child.stdout.take() {
80 use std::io::Read;
81 stdout.read_to_string(&mut output)?;
82 }
83
84 Ok(output)
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Running `echo hello` with empty input yields output containing
    /// `hello`.
    #[test]
    fn test_echo() {
        let stdout = run_piped("echo hello", "", None).unwrap();
        assert!(stdout.trim().contains("hello"));
    }

    /// Text passed as `input` must reach the child: filtering three lines
    /// for `foo` leaves exactly that line.
    #[test]
    fn test_input_piping() {
        let filter = if cfg!(windows) {
            "findstr foo"
        } else {
            "grep foo"
        };
        let stdout = run_piped(filter, "foo\nbar\nbaz", None).unwrap();
        assert_eq!(stdout.trim(), "foo");
    }

    /// A command that outlives its time limit is reported as a timeout, and
    /// the call returns before the command would have finished on its own.
    #[test]
    fn test_timeout() {
        let slow = if cfg!(windows) {
            "ping -n 3 127.0.0.1"
        } else {
            "sleep 2"
        };
        let limit = Duration::from_millis(500);
        let started = std::time::Instant::now();
        let outcome = run_piped(slow, "", Some(limit));
        assert!(matches!(outcome, Err(ShellError::Timeout(_, _))));
        assert!(started.elapsed() < Duration::from_secs(2));
    }

    /// A failing command produces `CommandFailed` carrying the exact command
    /// string that was run.
    #[test]
    fn test_command_failed_includes_command_name() {
        let failing = "exit 1";
        if let Err(ShellError::CommandFailed(reported, _)) = run_piped(failing, "", None) {
            assert_eq!(reported, failing);
        } else {
            panic!("Expected CommandFailed error");
        }
    }
}