1use std::{
4 convert::Infallible,
5 env, error,
6 ffi::OsStr,
7 fmt, io,
8 path::{Path, PathBuf},
9 process::Command,
10 time::Duration,
11};
12
13mod standard;
14mod transcript_impl;
15
16pub use self::standard::StdShell;
17
18use crate::{
19 traits::{ConfigureCommand, Echoing, SpawnShell, SpawnedShell},
20 Captured, ExitStatus,
21};
22
23type StatusCheckerFn = dyn Fn(&Captured) -> Option<ExitStatus>;
24
25pub(crate) struct StatusCheck {
26 command: String,
27 response_checker: Box<StatusCheckerFn>,
28}
29
30impl fmt::Debug for StatusCheck {
31 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
32 formatter
33 .debug_struct("StatusCheck")
34 .field("command", &self.command)
35 .finish()
36 }
37}
38
39impl StatusCheck {
40 pub fn command(&self) -> &str {
41 &self.command
42 }
43
44 pub fn check(&self, response: &Captured) -> Option<ExitStatus> {
45 (self.response_checker)(response)
46 }
47}
48
49pub struct ShellOptions<Cmd = Command> {
58 command: Cmd,
59 path_additions: Vec<PathBuf>,
60 io_timeout: Duration,
61 init_timeout: Duration,
62 init_commands: Vec<String>,
63 line_decoder: Box<dyn FnMut(Vec<u8>) -> io::Result<String>>,
64 status_check: Option<StatusCheck>,
65}
66
67impl<Cmd: fmt::Debug> fmt::Debug for ShellOptions<Cmd> {
68 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
69 formatter
70 .debug_struct("ShellOptions")
71 .field("command", &self.command)
72 .field("path_additions", &self.path_additions)
73 .field("io_timeout", &self.io_timeout)
74 .field("init_timeout", &self.init_timeout)
75 .field("init_commands", &self.init_commands)
76 .field("status_check", &self.status_check)
77 .finish()
78 }
79}
80
81#[cfg(any(unix, windows))]
82impl Default for ShellOptions {
83 fn default() -> Self {
84 Self::new(Self::default_shell())
85 }
86}
87
88impl<Cmd: ConfigureCommand> From<Cmd> for ShellOptions<Cmd> {
89 fn from(command: Cmd) -> Self {
90 Self::new(command)
91 }
92}
93
94impl<Cmd: ConfigureCommand> ShellOptions<Cmd> {
95 #[cfg(unix)]
96 fn default_shell() -> Command {
97 Command::new("sh")
98 }
99
100 #[cfg(windows)]
101 fn default_shell() -> Command {
102 let mut command = Command::new("cmd");
103 command.arg("/Q").arg("/K").arg("echo off && chcp 65001");
105 command
106 }
107
108 pub fn new(command: Cmd) -> Self {
110 Self {
111 command,
112 path_additions: vec![],
113 io_timeout: Duration::from_millis(500),
114 init_timeout: Duration::from_millis(1_500),
115 init_commands: vec![],
116 line_decoder: Box::new(|line| {
117 String::from_utf8(line)
118 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.utf8_error()))
119 }),
120 status_check: None,
121 }
122 }
123
124 #[must_use]
126 pub fn echoing(self, is_echoing: bool) -> ShellOptions<Echoing<Cmd>> {
127 ShellOptions {
128 command: Echoing::new(self.command, is_echoing),
129 path_additions: self.path_additions,
130 io_timeout: self.io_timeout,
131 init_timeout: self.init_timeout,
132 init_commands: self.init_commands,
133 line_decoder: self.line_decoder,
134 status_check: self.status_check,
135 }
136 }
137
138 #[must_use]
140 pub fn with_current_dir(mut self, current_dir: impl AsRef<Path>) -> Self {
141 self.command.current_dir(current_dir.as_ref());
142 self
143 }
144
145 #[must_use]
154 pub fn with_io_timeout(mut self, io_timeout: Duration) -> Self {
155 self.io_timeout = io_timeout;
156 self
157 }
158
159 #[must_use]
165 pub fn with_init_timeout(mut self, init_timeout: Duration) -> Self {
166 self.init_timeout = init_timeout;
167 self
168 }
169
170 #[must_use]
173 pub fn with_init_command(mut self, command: impl Into<String>) -> Self {
174 self.init_commands.push(command.into());
175 self
176 }
177
178 #[must_use]
180 pub fn with_env(mut self, name: impl AsRef<str>, value: impl AsRef<OsStr>) -> Self {
181 self.command.env(name.as_ref(), value.as_ref());
182 self
183 }
184
185 #[must_use]
191 pub fn with_line_decoder<E, F>(mut self, mut mapper: F) -> Self
192 where
193 E: Into<Box<dyn error::Error + Send + Sync>>,
194 F: FnMut(Vec<u8>) -> Result<String, E> + 'static,
195 {
196 self.line_decoder = Box::new(move |line| {
197 mapper(line).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
198 });
199 self
200 }
201
202 #[must_use]
205 pub fn with_lossy_utf8_decoder(self) -> Self {
206 self.with_line_decoder::<Infallible, _>(|line| {
207 Ok(String::from_utf8_lossy(&line).into_owned())
208 })
209 }
210
211 #[must_use]
236 pub fn with_status_check<F>(mut self, command: impl Into<String>, checker: F) -> Self
237 where
238 F: Fn(&Captured) -> Option<ExitStatus> + 'static,
239 {
240 let command = command.into();
241 assert!(
242 command.bytes().all(|ch| ch != b'\n' && ch != b'\r'),
243 "`command` contains a newline character ('\\n' or '\\r')"
244 );
245
246 self.status_check = Some(StatusCheck {
247 command,
248 response_checker: Box::new(checker),
249 });
250 self
251 }
252
253 fn target_path() -> PathBuf {
257 let mut path = env::current_exe().expect("Cannot obtain path to the executing file");
258 path.pop();
259 if path.ends_with("deps") {
260 path.pop();
261 }
262 path
263 }
264
265 #[must_use]
274 pub fn with_cargo_path(mut self) -> Self {
275 let target_path = Self::target_path();
276 self.path_additions.push(target_path.join("examples"));
277 self.path_additions.push(target_path);
278 self
279 }
280
281 #[must_use]
285 pub fn with_additional_path(mut self, path: impl Into<PathBuf>) -> Self {
286 let path = path.into();
287 self.path_additions.push(path);
288 self
289 }
290}
291
292impl<Cmd: SpawnShell> ShellOptions<Cmd> {
293 #[cfg_attr(
294 feature = "tracing",
295 tracing::instrument(
296 level = "debug",
297 skip(self),
298 err,
299 fields(self.path_additions = ?self.path_additions)
300 )
301 )]
302 fn spawn_shell(&mut self) -> io::Result<SpawnedShell<Cmd>> {
303 #[cfg(unix)]
304 const PATH_SEPARATOR: &str = ":";
305 #[cfg(windows)]
306 const PATH_SEPARATOR: &str = ";";
307
308 if !self.path_additions.is_empty() {
309 let mut path_var = env::var_os("PATH").unwrap_or_default();
310 if !path_var.is_empty() {
311 path_var.push(PATH_SEPARATOR);
312 }
313 for (i, addition) in self.path_additions.iter().enumerate() {
314 path_var.push(addition);
315 if i + 1 < self.path_additions.len() {
316 path_var.push(PATH_SEPARATOR);
317 }
318 }
319 self.command.env("PATH", &path_var);
320 }
321 self.command.spawn_shell()
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Transcript, UserInput};

    /// Runs two commands in the default shell and checks that both stdout and stderr output
    /// end up in the captured transcript.
    #[cfg(any(unix, windows))]
    #[test]
    fn creating_transcript_basics() -> anyhow::Result<()> {
        let mut options = ShellOptions::default();
        let inputs = vec![
            UserInput::command("echo hello"),
            UserInput::command("echo foo && echo bar >&2"),
        ];
        let transcript = Transcript::from_inputs(&mut options, inputs)?;

        let interactions = transcript.interactions();
        assert_eq!(interactions.len(), 2);

        let first = &interactions[0];
        assert_eq!(first.input().text, "echo hello");
        let first_output = first.output().as_ref();
        assert_eq!(first_output.trim(), "hello");

        let second = &interactions[1];
        assert_eq!(second.input().text, "echo foo && echo bar >&2");
        let second_output = second.output().as_ref();
        let words: Vec<_> = second_output.split_whitespace().collect();
        assert_eq!(words, ["foo", "bar"]);

        Ok(())
    }

    /// Checks that a command continued onto the next line with a trailing backslash is
    /// executed as a single interaction.
    #[cfg(unix)]
    #[test]
    fn transcript_with_multiline_input() -> anyhow::Result<()> {
        let inputs = vec![UserInput::command("echo \\\nhello")];
        let transcript = Transcript::from_inputs(&mut ShellOptions::default(), inputs)?;

        let interactions = transcript.interactions();
        assert_eq!(interactions.len(), 1);
        let output = interactions[0].output().as_ref();
        assert_eq!(output.trim(), "hello");

        Ok(())
    }
}