1use std::{
4 io::{self, BufRead, BufReader, LineWriter, Read},
5 iter,
6 process::{Command, Stdio},
7 sync::mpsc,
8 thread,
9 time::Duration,
10};
11
12use super::ShellOptions;
13use crate::{
14 traits::{ShellProcess, SpawnShell, SpawnedShell},
15 Captured, Interaction, Transcript, UserInput,
16};
17
18#[derive(Debug)]
19struct Timeouts {
20 inner: iter::Chain<iter::Once<Duration>, iter::Repeat<Duration>>,
21}
22
23impl Timeouts {
24 fn new<Cmd: SpawnShell>(options: &ShellOptions<Cmd>) -> Self {
25 Self {
26 inner: iter::once(options.init_timeout + options.io_timeout)
27 .chain(iter::repeat(options.io_timeout)),
28 }
29 }
30
31 fn next(&mut self) -> Duration {
32 self.inner.next().unwrap() }
34}
35
36impl Transcript {
37 #[cfg(not(windows))]
38 #[cfg_attr(
39 feature = "tracing",
40 tracing::instrument(level = "debug", skip(writer), err)
41 )]
42 fn write_line(writer: &mut impl io::Write, line: &str) -> io::Result<()> {
43 writeln!(writer, "{line}")
44 }
45
46 #[cfg(windows)]
48 #[cfg_attr(
49 feature = "tracing",
50 tracing::instrument(level = "debug", skip(writer), err)
51 )]
52 fn write_line(writer: &mut impl io::Write, line: &str) -> io::Result<()> {
53 writeln!(writer, "{line}\r")
54 }
55
56 #[cfg_attr(
57 feature = "tracing",
58 tracing::instrument(level = "debug", skip(lines_recv), err)
59 )]
60 #[cfg_attr(not(feature = "tracing"), allow(unused_variables))]
61 fn read_echo(
63 input_line: &str,
64 lines_recv: &mpsc::Receiver<Vec<u8>>,
65 io_timeout: Duration,
66 ) -> io::Result<()> {
67 if let Ok(line) = lines_recv.recv_timeout(io_timeout) {
68 #[cfg(feature = "tracing")]
69 tracing::debug!(line_utf8 = std::str::from_utf8(&line).ok(), "received line");
70 Ok(())
71 } else {
72 let err =
73 format!("could not read all input `{input_line}` back from an echoing terminal");
74 Err(io::Error::new(io::ErrorKind::BrokenPipe, err))
75 }
76 }
77
78 #[cfg_attr(
79 feature = "tracing",
80 tracing::instrument(level = "debug", skip_all, ret, err)
81 )]
82 fn read_output(
83 lines_recv: &mpsc::Receiver<Vec<u8>>,
84 mut timeouts: Timeouts,
85 line_decoder: &mut dyn FnMut(Vec<u8>) -> io::Result<String>,
86 ) -> io::Result<String> {
87 let mut output = String::new();
88
89 while let Ok(mut line) = lines_recv.recv_timeout(timeouts.next()) {
90 if line.last() == Some(&b'\r') {
91 line.pop();
93 }
94 #[cfg(feature = "tracing")]
95 tracing::debug!(line_utf8 = std::str::from_utf8(&line).ok(), "received line");
96
97 let mapped_line = line_decoder(line)?;
98 #[cfg(feature = "tracing")]
99 tracing::debug!(?mapped_line, "mapped received line");
100 output.push_str(&mapped_line);
101 output.push('\n');
102 }
103
104 if output.ends_with('\n') {
105 output.truncate(output.len() - 1);
106 }
107 Ok(output)
108 }
109
110 #[cfg_attr(
121 feature = "tracing",
122 tracing::instrument(
123 skip_all,
124 err,
125 fields(
126 options.io_timeout = ?options.io_timeout,
127 options.init_timeout = ?options.init_timeout,
128 options.path_additions = ?options.path_additions,
129 options.init_commands = ?options.init_commands
130 )
131 )
132 )]
133 pub fn from_inputs<Cmd: SpawnShell>(
134 options: &mut ShellOptions<Cmd>,
135 inputs: impl IntoIterator<Item = UserInput>,
136 ) -> io::Result<Self> {
137 let SpawnedShell {
138 mut shell,
139 reader,
140 writer,
141 } = options.spawn_shell()?;
142
143 let stdout = BufReader::new(reader);
144 let (out_lines_send, out_lines_recv) = mpsc::channel();
145 let io_handle = thread::spawn(move || {
146 #[cfg(feature = "tracing")]
147 let _entered = tracing::debug_span!("reader_thread").entered();
148
149 let mut lines = stdout.split(b'\n');
150 while let Some(Ok(line)) = lines.next() {
151 #[cfg(feature = "tracing")]
152 tracing::debug!(line_utf8 = std::str::from_utf8(&line).ok(), "received line");
153
154 if out_lines_send.send(line).is_err() {
155 #[cfg(feature = "tracing")]
156 tracing::debug!("receiver dropped, breaking reader loop");
157 break;
158 }
159 }
160 });
161
162 let mut stdin = LineWriter::new(writer);
163 Self::push_init_commands(options, &out_lines_recv, &mut shell, &mut stdin)?;
164
165 let mut transcript = Self::new();
166 for input in inputs {
167 let interaction =
168 Self::record_interaction(options, input, &out_lines_recv, &mut shell, &mut stdin)?;
169 transcript.interactions.push(interaction);
170 }
171
172 drop(stdin); thread::sleep(options.io_timeout / 4);
176
177 shell.terminate()?;
178 io_handle.join().ok(); Ok(transcript)
180 }
181
182 #[cfg_attr(
183 feature = "tracing",
184 tracing::instrument(
185 level = "debug",
186 skip_all,
187 err,
188 fields(options.init_commands = ?options.init_commands)
189 )
190 )]
191 fn push_init_commands<Cmd: SpawnShell>(
192 options: &ShellOptions<Cmd>,
193 lines_recv: &mpsc::Receiver<Vec<u8>>,
194 shell: &mut Cmd::ShellProcess,
195 stdin: &mut impl io::Write,
196 ) -> io::Result<()> {
197 let mut timeouts = Timeouts::new(options);
199 while lines_recv.recv_timeout(timeouts.next()).is_ok() {
200 }
202
203 for cmd in &options.init_commands {
205 Self::write_line(stdin, cmd)?;
206 if shell.is_echoing() {
207 Self::read_echo(cmd, lines_recv, options.io_timeout)?;
208 }
209
210 let mut timeouts = Timeouts::new(options);
212 while lines_recv.recv_timeout(timeouts.next()).is_ok() {
213 }
215 }
216 Ok(())
217 }
218
219 #[cfg_attr(
220 feature = "tracing",
221 tracing::instrument(level = "debug", skip(options, lines_recv, shell, stdin), ret, err)
222 )]
223 fn record_interaction<Cmd: SpawnShell>(
224 options: &mut ShellOptions<Cmd>,
225 input: UserInput,
226 lines_recv: &mpsc::Receiver<Vec<u8>>,
227 shell: &mut Cmd::ShellProcess,
228 stdin: &mut impl io::Write,
229 ) -> io::Result<Interaction> {
230 shell.check_is_alive()?;
233
234 let input_lines = input.text.split('\n');
235 for input_line in input_lines {
236 Self::write_line(stdin, input_line)?;
237 if shell.is_echoing() {
238 Self::read_echo(input_line, lines_recv, options.io_timeout)?;
239 }
240 }
241
242 let output = Self::read_output(
243 lines_recv,
244 Timeouts::new(options),
245 options.line_decoder.as_mut(),
246 )?;
247
248 let exit_status = if let Some(status_check) = &options.status_check {
249 let command = status_check.command();
250 Self::write_line(stdin, command)?;
251 if shell.is_echoing() {
252 Self::read_echo(command, lines_recv, options.io_timeout)?;
253 }
254 let response = Self::read_output(
255 lines_recv,
256 Timeouts::new(options),
257 options.line_decoder.as_mut(),
258 )?;
259 status_check.check(&Captured::from(response))
260 } else {
261 None
262 };
263
264 let mut interaction = Interaction::new(input, output);
265 interaction.exit_status = exit_status;
266 Ok(interaction)
267 }
268
269 #[cfg_attr(
279 feature = "tracing",
280 tracing::instrument(skip(self, input), err, fields(input.text = %input.text))
281 )]
282 pub fn capture_output(
283 &mut self,
284 input: UserInput,
285 command: &mut Command,
286 ) -> io::Result<&mut Self> {
287 let (mut pipe_reader, pipe_writer) = os_pipe::pipe()?;
288 #[cfg(feature = "tracing")]
289 tracing::debug!("created OS pipe");
290
291 let mut child = command
292 .stdin(Stdio::null())
293 .stdout(pipe_writer.try_clone()?)
294 .stderr(pipe_writer)
295 .spawn()?;
296 #[cfg(feature = "tracing")]
297 tracing::debug!("created child");
298
299 command.stdout(Stdio::null()).stderr(Stdio::null());
301
302 let mut output = vec![];
303 pipe_reader.read_to_end(&mut output)?;
304 child.wait()?;
305
306 let output = String::from_utf8(output)
307 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.utf8_error()))?;
308 #[cfg(feature = "tracing")]
309 tracing::debug!(?output, "read command output");
310
311 self.interactions.push(Interaction::new(input, output));
312 Ok(self)
313 }
314}