1use std::io::Write;
2use std::process::{Command, Stdio};
3use std::time::Duration;
4use thiserror::Error;
5use wait_timeout::ChildExt;
6
7#[derive(Debug, Error)]
8pub enum ShellError {
9 #[error("IO error: {0}")]
10 Io(#[from] std::io::Error),
11 #[error("Command `{0}` timed out after {1:?}")]
12 Timeout(String, Duration),
13 #[error("Command `{0}` failed with status {1}")]
14 CommandFailed(String, std::process::ExitStatus),
15 #[error("Command output was not valid UTF-8")]
16 InvalidUtf8(#[from] std::string::FromUtf8Error),
17}
18
19pub fn run_piped(
34 command_str: &str,
35 input: &str,
36 timeout: Option<Duration>,
37) -> Result<String, ShellError> {
38 let mut cmd = if cfg!(target_os = "windows") {
39 let mut c = Command::new("cmd");
40 c.arg("/C").arg(command_str);
41 c
42 } else {
43 let mut c = Command::new("sh");
44 c.arg("-c").arg(command_str);
45 c
46 };
47
48 cmd.stdin(Stdio::piped())
49 .stdout(Stdio::piped())
50 .stderr(Stdio::inherit());
51
52 let mut child = cmd.spawn()?;
53
54 if let Some(mut stdin) = child.stdin.take() {
55 stdin.write_all(input.as_bytes())?;
56 }
57
58 match timeout {
59 Some(duration) => match child.wait_timeout(duration)? {
60 Some(status) => {
61 if !status.success() {
62 return Err(ShellError::CommandFailed(command_str.to_string(), status));
63 }
64 }
65 None => {
66 child.kill()?;
67 return Err(ShellError::Timeout(command_str.to_string(), duration));
68 }
69 },
70 None => {
71 let status = child.wait()?;
72 if !status.success() {
73 return Err(ShellError::CommandFailed(command_str.to_string(), status));
74 }
75 }
76 }
77
78 let mut output = String::new();
79 if let Some(mut stdout) = child.stdout.take() {
80 use std::io::Read;
81 stdout.read_to_string(&mut output)?;
82 }
83
84 Ok(output)
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_echo() {
        // `echo` behaves the same under `cmd /C` and `sh -c`, so no
        // per-platform command is needed.
        let output = run_piped("echo hello", "", None).unwrap();
        assert!(output.trim().contains("hello"));
    }

    #[test]
    fn test_input_piping() {
        let cmd = if cfg!(windows) { "findstr foo" } else { "grep foo" };
        let input = "foo\nbar\nbaz";
        let output = run_piped(cmd, input, None).unwrap();
        assert_eq!(output.trim(), "foo");
    }

    #[test]
    fn test_completes_within_timeout() {
        // A generous timeout must not interfere with a command that finishes
        // quickly; its output is still returned.
        let output = run_piped("echo hello", "", Some(Duration::from_secs(10))).unwrap();
        assert!(output.trim().contains("hello"));
    }

    #[test]
    fn test_timeout() {
        let cmd = if cfg!(windows) {
            "ping -n 3 127.0.0.1"
        } else {
            "sleep 2"
        };
        let start = std::time::Instant::now();
        let res = run_piped(cmd, "", Some(Duration::from_millis(500)));
        assert!(matches!(res, Err(ShellError::Timeout(_, _))));
        assert!(start.elapsed() < Duration::from_secs(2));
    }

    #[test]
    fn test_command_failed_includes_command_name() {
        // `exit 1` works under both `cmd /C` and `sh -c`.
        let cmd = "exit 1";
        match run_piped(cmd, "", None) {
            Err(ShellError::CommandFailed(cmd_str, _)) => assert_eq!(cmd_str, cmd),
            _ => panic!("Expected CommandFailed error"),
        }
    }
}