1use std::{
2 collections::HashMap,
3 env,
4 ffi::{OsStr, OsString},
5 path::{Path, PathBuf},
6 process::Command,
7 sync::{Arc, Mutex},
8 time::Duration,
9};
10
11use crate::{Error, Result, ShellTaskLog};
12use crossbeam_channel::{unbounded, Receiver, Sender};
13
14mod behavior;
15mod output;
16mod runner;
17
18pub use behavior::ShellTaskBehavior;
19pub use output::ShellTaskOutput;
20use runner::ShellTaskRunner;
21
/// A single shell command, parsed once by `new` and executed with `run`.
#[derive(Debug)]
pub struct ShellTask {
    /// Program to execute: the first word of the command string given to `new`.
    bin: String,
    /// The remaining words of the command string, passed as arguments to `bin`.
    args: Vec<String>,
    /// Working directory for the spawned process. `new` initialises it to this
    /// process's current directory; `current_dir` overrides it.
    current_dir: PathBuf,
    /// Extra environment variables added with `env` and applied to the command.
    envs: HashMap<OsString, OsString>,
    /// The command string exactly as passed to `new`; used by `descriptor` and
    /// in error values.
    full_command: String,
    /// Sending half of the `ShellTaskLog` channel created in `new`.
    log_sender: Sender<ShellTaskLog>,
    /// Receiving half of the `ShellTaskLog` channel created in `new`.
    log_receiver: Receiver<ShellTaskLog>,
}
34
35impl ShellTask {
36 pub fn new(command: &str) -> Result<Self> {
38 let current_dir =
39 env::current_dir().map_err(|source| Error::CouldNotFindCurrentDirectory { source })?;
40 let command = command.to_string();
41 let args: Vec<&str> = command.split(' ').collect();
42 let (bin, args) = match args.len() {
43 0 => Err(Error::InvalidTask {
44 task: command.to_string(),
45 reason: "an empty string is not a command".to_string(),
46 }),
47 1 => Ok((args[0], Vec::new())),
48 _ => Ok((args[0], Vec::from_iter(args[1..].iter()))),
49 }?;
50
51 if which::which(bin).is_err() {
52 Err(Error::InvalidTask {
53 task: command.to_string(),
54 reason: format!("'{}' is not installed on this machine", &bin),
55 })
56 } else {
57 let (log_sender, log_receiver) = unbounded();
58 Ok(Self {
59 bin: bin.to_string(),
60 args: args.iter().map(|s| s.to_string()).collect(),
61 full_command: command,
62 envs: HashMap::new(),
63 current_dir,
64 log_sender,
65 log_receiver,
66 })
67 }
68 }
69
70 pub fn env<K, V>(&mut self, key: K, value: V) -> &mut ShellTask
72 where
73 K: AsRef<OsStr>,
74 V: AsRef<OsStr>,
75 {
76 self.envs
77 .insert(key.as_ref().to_os_string(), value.as_ref().to_os_string());
78 self
79 }
80
81 pub fn current_dir<P>(&mut self, path: P)
83 where
84 P: AsRef<Path>,
85 {
86 self.current_dir = path.as_ref().to_path_buf();
87 }
88
89 pub fn descriptor(&self) -> String {
91 self.full_command.to_string()
92 }
93
94 pub fn bash_descriptor(&self) -> String {
96 format!("$ {}", self.descriptor())
97 }
98
99 fn get_command(&self) -> Command {
101 let mut command = Command::new(&self.bin);
102 command
103 .args(&self.args)
104 .envs(&self.envs)
105 .current_dir(&self.current_dir);
106 command
107 }
108
109 pub fn run<F, T>(&self, log_handler: F) -> Result<ShellTaskOutput<T>>
161 where
162 F: Fn(ShellTaskLog) -> ShellTaskBehavior<T> + Send + Sync + 'static,
163 T: Send + Sync + 'static,
164 {
165 let log_drain: Arc<Mutex<Vec<ShellTaskLog>>> = Arc::new(Mutex::new(Vec::new()));
166 let log_drainer = log_drain.clone();
167 let log_drain_filler = log_drain.clone();
168 let log_receiver = self.log_receiver.clone();
169 let full_command = self.full_command.to_string();
170
171 let maybe_result = Arc::new(Mutex::new(None));
172 let early_terminator = maybe_result.clone();
173
174 let collected_stdout_lines = Arc::new(Mutex::new(Vec::new()));
175 let collected_stderr_lines = Arc::new(Mutex::new(Vec::new()));
176 let stdout_collector = collected_stdout_lines.clone();
177 let stderr_collector = collected_stderr_lines.clone();
178
179 rayon::spawn(move || {
180 while let Ok(line) = log_receiver.recv() {
181 match &line {
182 ShellTaskLog::Stderr(stderr) => {
183 if let Ok(mut stderr_lines) = stderr_collector.clone().lock() {
184 stderr_lines.push(stderr.to_string())
185 }
186 }
187 ShellTaskLog::Stdout(stdout) => {
188 if let Ok(mut stdout_lines) = stdout_collector.clone().lock() {
189 stdout_lines.push(stdout.to_string())
190 }
191 }
192 }
193
194 if let Ok(mut log_decrementer) = log_drainer.clone().lock() {
195 if let Some(stderr_pos) = log_decrementer
196 .iter()
197 .position(|e| matches!(e, ShellTaskLog::Stderr(_)))
198 {
199 log_decrementer.remove(stderr_pos);
200 } else if let Some(stdout_pos) = log_decrementer
201 .iter()
202 .position(|e| matches!(e, ShellTaskLog::Stdout(_)))
203 {
204 log_decrementer.remove(stdout_pos);
205 }
206 match (log_handler)(line) {
207 ShellTaskBehavior::EarlyReturn(early_return) => {
208 if let Ok(mut maybe_result) = early_terminator.lock() {
209 if maybe_result.is_none() {
210 *maybe_result = Some(early_return);
211 break;
212 }
213 }
214 }
215 ShellTaskBehavior::Passthrough => continue,
216 }
217 } else if let Ok(mut maybe_result) = early_terminator.lock() {
218 if maybe_result.is_none() {
219 *maybe_result =
220 Some(Err(Box::new(Error::PoisonedLog { task: full_command })));
221 break;
222 }
223 } else {
224 continue;
225 }
226 }
227 });
228
229 let task = ShellTaskRunner::run(
230 self.get_command(),
231 self.full_command.to_string(),
232 self.log_sender.clone(),
233 log_drain_filler,
234 )?;
235
236 let output = task
237 .child
238 .wait_with_output()
239 .map_err(|source| Error::CouldNotWait {
240 task: self.full_command.to_string(),
241 source,
242 })?;
243
244 loop {
246 std::thread::sleep(Duration::from_millis(200));
247 match log_drain.try_lock() {
248 Ok(log_drain) => {
249 if log_drain.is_empty() {
250 break;
251 } else {
252 continue;
253 }
254 }
255 _ => continue,
256 }
257 }
258
259 if output.status.success() {
260 let collected_stderr_lines = collected_stderr_lines.lock().unwrap().to_vec();
261 let collected_stdout_lines = collected_stdout_lines.lock().unwrap().to_vec();
262 if let Some(result) = maybe_result.clone().lock().unwrap().take() {
263 result
264 .map(|t| ShellTaskOutput::EarlyReturn {
265 stderr_lines: collected_stderr_lines,
266 stdout_lines: collected_stdout_lines,
267 return_value: t,
268 })
269 .map_err(|e| e.into())
270 } else {
271 Ok(ShellTaskOutput::CompleteOutput {
272 status: output.status,
273 stdout_lines: collected_stdout_lines,
274 stderr_lines: collected_stderr_lines,
275 })
276 }
277 } else {
278 Err(Error::TaskFailure {
279 task: self.full_command.to_string(),
280 exit_status: output.status,
281 })
282 }
283 }
284}