1use std::{
4 convert::Infallible,
5 env, error,
6 ffi::OsStr,
7 fmt, io,
8 path::{Path, PathBuf},
9 process::Command,
10 time::Duration,
11};
12
13mod standard;
14mod transcript_impl;
15
16pub use self::standard::StdShell;
17use crate::{
18 traits::{ConfigureCommand, Echoing, SpawnShell, SpawnedShell},
19 Captured, ExitStatus,
20};
21
/// Callback that inspects the captured response to a status-check command and
/// extracts an exit status from it (`None` if no status is extracted).
type StatusCheckerFn = dyn Fn(&Captured) -> Option<ExitStatus>;

/// A command paired with the callback that interprets its captured response.
///
/// Constructed by `ShellOptions::with_status_check`, which guarantees that
/// `command` contains no `\n` / `\r` characters.
pub(crate) struct StatusCheck {
    /// Shell command associated with this check (exposed via `command()`).
    /// NOTE(review): where it is sent to the shell is outside this view.
    command: String,
    /// Maps the captured response to an exit status; invoked by `check()`.
    response_checker: Box<StatusCheckerFn>,
}
28
29impl fmt::Debug for StatusCheck {
30 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
31 formatter
32 .debug_struct("StatusCheck")
33 .field("command", &self.command)
34 .finish_non_exhaustive()
35 }
36}
37
38impl StatusCheck {
39 pub fn command(&self) -> &str {
40 &self.command
41 }
42
43 pub fn check(&self, response: &Captured) -> Option<ExitStatus> {
44 (self.response_checker)(response)
45 }
46}
47
/// Options for spawning a shell and capturing its output.
///
/// `Cmd` is the command used to launch the shell; it defaults to [`Command`].
/// Options are configured through the builder-style `with_*` methods.
pub struct ShellOptions<Cmd = Command> {
    /// Command that is spawned as the shell.
    command: Cmd,
    /// Directories appended to the inherited `PATH` when the shell is spawned.
    path_additions: Vec<PathBuf>,
    /// I/O timeout (500 ms by default).
    /// NOTE(review): semantics are defined outside this view; presumably how long to wait for output.
    io_timeout: Duration,
    /// Initialization timeout (1.5 s by default).
    /// NOTE(review): presumably applies while the init commands run; confirm.
    init_timeout: Duration,
    /// Commands registered via `with_init_command`, in registration order.
    init_commands: Vec<String>,
    /// Converts a raw output line (bytes) into a string; strict UTF-8 by default.
    line_decoder: Box<dyn FnMut(Vec<u8>) -> io::Result<String>>,
    /// Optional command + checker used to obtain exit statuses; `None` by default.
    status_check: Option<StatusCheck>,
}
65
66impl<Cmd: fmt::Debug> fmt::Debug for ShellOptions<Cmd> {
67 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
68 formatter
69 .debug_struct("ShellOptions")
70 .field("command", &self.command)
71 .field("path_additions", &self.path_additions)
72 .field("io_timeout", &self.io_timeout)
73 .field("init_timeout", &self.init_timeout)
74 .field("init_commands", &self.init_commands)
75 .field("status_check", &self.status_check)
76 .finish_non_exhaustive()
77 }
78}
79
80#[cfg(any(unix, windows))]
81impl Default for ShellOptions {
82 fn default() -> Self {
83 Self::new(Self::default_shell())
84 }
85}
86
87impl<Cmd: ConfigureCommand> From<Cmd> for ShellOptions<Cmd> {
88 fn from(command: Cmd) -> Self {
89 Self::new(command)
90 }
91}
92
93impl<Cmd: ConfigureCommand> ShellOptions<Cmd> {
94 #[cfg(unix)]
95 fn default_shell() -> Command {
96 Command::new("sh")
97 }
98
99 #[cfg(windows)]
100 fn default_shell() -> Command {
101 let mut command = Command::new("cmd");
102 command.arg("/Q").arg("/K").arg("echo off && chcp 65001");
104 command
105 }
106
107 pub fn new(command: Cmd) -> Self {
109 Self {
110 command,
111 path_additions: vec![],
112 io_timeout: Duration::from_millis(500),
113 init_timeout: Duration::from_millis(1_500),
114 init_commands: vec![],
115 line_decoder: Box::new(|line| {
116 String::from_utf8(line)
117 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.utf8_error()))
118 }),
119 status_check: None,
120 }
121 }
122
123 #[must_use]
125 pub fn echoing(self, is_echoing: bool) -> ShellOptions<Echoing<Cmd>> {
126 ShellOptions {
127 command: Echoing::new(self.command, is_echoing),
128 path_additions: self.path_additions,
129 io_timeout: self.io_timeout,
130 init_timeout: self.init_timeout,
131 init_commands: self.init_commands,
132 line_decoder: self.line_decoder,
133 status_check: self.status_check,
134 }
135 }
136
137 #[must_use]
139 pub fn with_current_dir(mut self, current_dir: impl AsRef<Path>) -> Self {
140 self.command.current_dir(current_dir.as_ref());
141 self
142 }
143
144 #[must_use]
153 pub fn with_io_timeout(mut self, io_timeout: Duration) -> Self {
154 self.io_timeout = io_timeout;
155 self
156 }
157
158 #[must_use]
164 pub fn with_init_timeout(mut self, init_timeout: Duration) -> Self {
165 self.init_timeout = init_timeout;
166 self
167 }
168
169 #[must_use]
172 pub fn with_init_command(mut self, command: impl Into<String>) -> Self {
173 self.init_commands.push(command.into());
174 self
175 }
176
177 #[must_use]
179 pub fn with_env(mut self, name: impl AsRef<str>, value: impl AsRef<OsStr>) -> Self {
180 self.command.env(name.as_ref(), value.as_ref());
181 self
182 }
183
184 #[must_use]
190 pub fn with_line_decoder<E, F>(mut self, mut mapper: F) -> Self
191 where
192 E: Into<Box<dyn error::Error + Send + Sync>>,
193 F: FnMut(Vec<u8>) -> Result<String, E> + 'static,
194 {
195 self.line_decoder = Box::new(move |line| {
196 mapper(line).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
197 });
198 self
199 }
200
201 #[must_use]
204 pub fn with_lossy_utf8_decoder(self) -> Self {
205 self.with_line_decoder::<Infallible, _>(|line| {
206 Ok(String::from_utf8_lossy(&line).into_owned())
207 })
208 }
209
210 #[must_use]
235 pub fn with_status_check<F>(mut self, command: impl Into<String>, checker: F) -> Self
236 where
237 F: Fn(&Captured) -> Option<ExitStatus> + 'static,
238 {
239 let command = command.into();
240 assert!(
241 command.bytes().all(|ch| ch != b'\n' && ch != b'\r'),
242 "`command` contains a newline character ('\\n' or '\\r')"
243 );
244
245 self.status_check = Some(StatusCheck {
246 command,
247 response_checker: Box::new(checker),
248 });
249 self
250 }
251
252 fn target_path() -> PathBuf {
256 let mut path = env::current_exe().expect("Cannot obtain path to the executing file");
257 path.pop();
258 if path.ends_with("deps") {
259 path.pop();
260 }
261 path
262 }
263
264 #[must_use]
273 pub fn with_cargo_path(mut self) -> Self {
274 let target_path = Self::target_path();
275 self.path_additions.push(target_path.join("examples"));
276 self.path_additions.push(target_path);
277 self
278 }
279
280 #[must_use]
284 pub fn with_additional_path(mut self, path: impl Into<PathBuf>) -> Self {
285 let path = path.into();
286 self.path_additions.push(path);
287 self
288 }
289}
290
291impl<Cmd: SpawnShell> ShellOptions<Cmd> {
292 #[cfg_attr(
293 feature = "tracing",
294 tracing::instrument(
295 level = "debug",
296 skip(self),
297 err,
298 fields(self.path_additions = ?self.path_additions)
299 )
300 )]
301 fn spawn_shell(&mut self) -> io::Result<SpawnedShell<Cmd>> {
302 #[cfg(unix)]
303 const PATH_SEPARATOR: &str = ":";
304 #[cfg(windows)]
305 const PATH_SEPARATOR: &str = ";";
306
307 if !self.path_additions.is_empty() {
308 let mut path_var = env::var_os("PATH").unwrap_or_default();
309 if !path_var.is_empty() {
310 path_var.push(PATH_SEPARATOR);
311 }
312 for (i, addition) in self.path_additions.iter().enumerate() {
313 path_var.push(addition);
314 if i + 1 < self.path_additions.len() {
315 path_var.push(PATH_SEPARATOR);
316 }
317 }
318 self.command.env("PATH", &path_var);
319 }
320 self.command.spawn_shell()
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Transcript, UserInput};

    /// Runs two commands through the default shell; the second one writes to both
    /// stdout and stderr, and both words must appear in the captured output.
    #[cfg(any(unix, windows))]
    #[test]
    fn creating_transcript_basics() -> anyhow::Result<()> {
        let mut options = ShellOptions::default();
        let transcript = Transcript::from_inputs(
            &mut options,
            vec![
                UserInput::command("echo hello"),
                UserInput::command("echo foo && echo bar >&2"),
            ],
        )?;

        let interactions = transcript.interactions();
        assert_eq!(interactions.len(), 2);

        let hello = &interactions[0];
        assert_eq!(hello.input().text, "echo hello");
        let hello_output = hello.output().as_ref();
        assert_eq!(hello_output.trim(), "hello");

        let mixed = &interactions[1];
        assert_eq!(mixed.input().text, "echo foo && echo bar >&2");
        let mixed_output = mixed.output().as_ref();
        let words: Vec<_> = mixed_output.split_whitespace().collect();
        assert_eq!(words, ["foo", "bar"]);
        Ok(())
    }

    /// A command continued onto a second line with a trailing backslash is still
    /// executed as a single interaction.
    #[cfg(unix)]
    #[test]
    fn transcript_with_multiline_input() -> anyhow::Result<()> {
        let mut options = ShellOptions::default();
        let transcript =
            Transcript::from_inputs(&mut options, vec![UserInput::command("echo \\\nhello")])?;

        let interactions = transcript.interactions();
        assert_eq!(interactions.len(), 1);
        let output = interactions[0].output().as_ref();
        assert_eq!(output.trim(), "hello");
        Ok(())
    }
}