// shell_candy/task/mod.rs

use std::{
    collections::HashMap,
    env,
    ffi::{OsStr, OsString},
    path::{Path, PathBuf},
    process::Command,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc, Mutex, TryLockError,
    },
    time::Duration,
};

use crate::{Error, Result, ShellTaskLog};
use crossbeam_channel::{unbounded, Receiver, Sender};

mod behavior;
mod output;
mod runner;

pub use behavior::ShellTaskBehavior;
pub use output::ShellTaskOutput;
use runner::ShellTaskRunner;
21
22/// A [`ShellTask`] runs commands and provides a passthrough log handler
23/// for each log line.
24#[derive(Debug)]
25pub struct ShellTask {
26    bin: String,
27    args: Vec<String>,
28    current_dir: PathBuf,
29    envs: HashMap<OsString, OsString>,
30    full_command: String,
31    log_sender: Sender<ShellTaskLog>,
32    log_receiver: Receiver<ShellTaskLog>,
33}
34
35impl ShellTask {
36    /// Create a new [`ShellTask`] with a log line handler.
37    pub fn new(command: &str) -> Result<Self> {
38        let current_dir =
39            env::current_dir().map_err(|source| Error::CouldNotFindCurrentDirectory { source })?;
40        let command = command.to_string();
41        let args: Vec<&str> = command.split(' ').collect();
42        let (bin, args) = match args.len() {
43            0 => Err(Error::InvalidTask {
44                task: command.to_string(),
45                reason: "an empty string is not a command".to_string(),
46            }),
47            1 => Ok((args[0], Vec::new())),
48            _ => Ok((args[0], Vec::from_iter(args[1..].iter()))),
49        }?;
50
51        if which::which(bin).is_err() {
52            Err(Error::InvalidTask {
53                task: command.to_string(),
54                reason: format!("'{}' is not installed on this machine", &bin),
55            })
56        } else {
57            let (log_sender, log_receiver) = unbounded();
58            Ok(Self {
59                bin: bin.to_string(),
60                args: args.iter().map(|s| s.to_string()).collect(),
61                full_command: command,
62                envs: HashMap::new(),
63                current_dir,
64                log_sender,
65                log_receiver,
66            })
67        }
68    }
69
70    /// Adds an environment variable to the command run by [`ShellTask`].
71    pub fn env<K, V>(&mut self, key: K, value: V) -> &mut ShellTask
72    where
73        K: AsRef<OsStr>,
74        V: AsRef<OsStr>,
75    {
76        self.envs
77            .insert(key.as_ref().to_os_string(), value.as_ref().to_os_string());
78        self
79    }
80
81    /// Sets the directory the command should be run in.
82    pub fn current_dir<P>(&mut self, path: P)
83    where
84        P: AsRef<Path>,
85    {
86        self.current_dir = path.as_ref().to_path_buf();
87    }
88
89    /// Returns the full command that was used to instantiate this [`ShellTask`].
90    pub fn descriptor(&self) -> String {
91        self.full_command.to_string()
92    }
93
94    /// Returns the [`ShellTask::descriptor`] with the classic `$` shell prefix.
95    pub fn bash_descriptor(&self) -> String {
96        format!("$ {}", self.descriptor())
97    }
98
99    /// Returns the [`ShellTaskRunner`] from the internal configuration.
100    fn get_command(&self) -> Command {
101        let mut command = Command::new(&self.bin);
102        command
103            .args(&self.args)
104            .envs(&self.envs)
105            .current_dir(&self.current_dir);
106        command
107    }
108
109    /// Run a [`ShellTask`], applying the log handler to each line.
110    ///
111    /// You can make the task terminate early if your `log_handler`
112    /// returns [`ShellTaskBehavior::EarlyReturn<T>`]. When this variant
113    /// is returned from a log handler, [`ShellTask::run`] will return [`Some<T>`].
114    ///
115    /// # Example
116    ///
117    /// ```
118    /// use anyhow::anyhow;
119    /// use shell_candy::{ShellTask, ShellTaskLog, ShellTaskOutput, ShellTaskBehavior};
120    ///
121    /// fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
122    ///     let result = ShellTask::new("rustc --version")?.run(|line| {
123    ///         match line {
124    ///             ShellTaskLog::Stderr(_) => {
125    ///                 ShellTaskBehavior::Passthrough
126    ///             },
127    ///             ShellTaskLog::Stdout(message) => {
128    ///                 eprintln!("{}", &message);
129    ///                 ShellTaskBehavior::EarlyReturn(Ok(message))
130    ///             }
131    ///         }
132    ///     })?;
133    ///     assert!(matches!(result, ShellTaskOutput::EarlyReturn { .. }));
134    ///     Ok(())
135    /// }
136    /// ```
137    ///
138    /// If your `log_handler` returns [`ShellTaskBehavior::Passthrough`] for
139    /// the entire lifecycle of the task, [`ShellTask::run`] will return [`ShellTaskOutput::CompleteOutput`].
140    ///
141    /// # Example
142    ///
143    /// ```
144    /// use anyhow::anyhow;
145    /// use shell_candy::{ShellTask, ShellTaskLog, ShellTaskOutput, ShellTaskBehavior};
146    ///
147    /// fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
148    ///     let result = ShellTask::new("rustc --version")?.run(|line| {
149    ///         match line {
150    ///             ShellTaskLog::Stderr(message) | ShellTaskLog::Stdout(message) => {
151    ///                 eprintln!("info: {}", &message);
152    ///                 ShellTaskBehavior::<()>::Passthrough
153    ///             }
154    ///         }
155    ///     })?;
156    ///     assert!(matches!(result, ShellTaskOutput::CompleteOutput { .. }));
157    ///     Ok(())
158    /// }
159    /// ```
160    pub fn run<F, T>(&self, log_handler: F) -> Result<ShellTaskOutput<T>>
161    where
162        F: Fn(ShellTaskLog) -> ShellTaskBehavior<T> + Send + Sync + 'static,
163        T: Send + Sync + 'static,
164    {
165        let log_drain: Arc<Mutex<Vec<ShellTaskLog>>> = Arc::new(Mutex::new(Vec::new()));
166        let log_drainer = log_drain.clone();
167        let log_drain_filler = log_drain.clone();
168        let log_receiver = self.log_receiver.clone();
169        let full_command = self.full_command.to_string();
170
171        let maybe_result = Arc::new(Mutex::new(None));
172        let early_terminator = maybe_result.clone();
173
174        let collected_stdout_lines = Arc::new(Mutex::new(Vec::new()));
175        let collected_stderr_lines = Arc::new(Mutex::new(Vec::new()));
176        let stdout_collector = collected_stdout_lines.clone();
177        let stderr_collector = collected_stderr_lines.clone();
178
179        rayon::spawn(move || {
180            while let Ok(line) = log_receiver.recv() {
181                match &line {
182                    ShellTaskLog::Stderr(stderr) => {
183                        if let Ok(mut stderr_lines) = stderr_collector.clone().lock() {
184                            stderr_lines.push(stderr.to_string())
185                        }
186                    }
187                    ShellTaskLog::Stdout(stdout) => {
188                        if let Ok(mut stdout_lines) = stdout_collector.clone().lock() {
189                            stdout_lines.push(stdout.to_string())
190                        }
191                    }
192                }
193
194                if let Ok(mut log_decrementer) = log_drainer.clone().lock() {
195                    if let Some(stderr_pos) = log_decrementer
196                        .iter()
197                        .position(|e| matches!(e, ShellTaskLog::Stderr(_)))
198                    {
199                        log_decrementer.remove(stderr_pos);
200                    } else if let Some(stdout_pos) = log_decrementer
201                        .iter()
202                        .position(|e| matches!(e, ShellTaskLog::Stdout(_)))
203                    {
204                        log_decrementer.remove(stdout_pos);
205                    }
206                    match (log_handler)(line) {
207                        ShellTaskBehavior::EarlyReturn(early_return) => {
208                            if let Ok(mut maybe_result) = early_terminator.lock() {
209                                if maybe_result.is_none() {
210                                    *maybe_result = Some(early_return);
211                                    break;
212                                }
213                            }
214                        }
215                        ShellTaskBehavior::Passthrough => continue,
216                    }
217                } else if let Ok(mut maybe_result) = early_terminator.lock() {
218                    if maybe_result.is_none() {
219                        *maybe_result =
220                            Some(Err(Box::new(Error::PoisonedLog { task: full_command })));
221                        break;
222                    }
223                } else {
224                    continue;
225                }
226            }
227        });
228
229        let task = ShellTaskRunner::run(
230            self.get_command(),
231            self.full_command.to_string(),
232            self.log_sender.clone(),
233            log_drain_filler,
234        )?;
235
236        let output = task
237            .child
238            .wait_with_output()
239            .map_err(|source| Error::CouldNotWait {
240                task: self.full_command.to_string(),
241                source,
242            })?;
243
244        // wait until the log drain is empty so we know they've all been processed
245        loop {
246            std::thread::sleep(Duration::from_millis(200));
247            match log_drain.try_lock() {
248                Ok(log_drain) => {
249                    if log_drain.is_empty() {
250                        break;
251                    } else {
252                        continue;
253                    }
254                }
255                _ => continue,
256            }
257        }
258
259        if output.status.success() {
260            let collected_stderr_lines = collected_stderr_lines.lock().unwrap().to_vec();
261            let collected_stdout_lines = collected_stdout_lines.lock().unwrap().to_vec();
262            if let Some(result) = maybe_result.clone().lock().unwrap().take() {
263                result
264                    .map(|t| ShellTaskOutput::EarlyReturn {
265                        stderr_lines: collected_stderr_lines,
266                        stdout_lines: collected_stdout_lines,
267                        return_value: t,
268                    })
269                    .map_err(|e| e.into())
270            } else {
271                Ok(ShellTaskOutput::CompleteOutput {
272                    status: output.status,
273                    stdout_lines: collected_stdout_lines,
274                    stderr_lines: collected_stderr_lines,
275                })
276            }
277        } else {
278            Err(Error::TaskFailure {
279                task: self.full_command.to_string(),
280                exit_status: output.status,
281            })
282        }
283    }
284}