1use std::{
6 collections::HashMap,
7 error::Error as StdError,
8 ffi::{OsStr, OsString},
9 io,
10 path::{Path, PathBuf},
11};
12
13use portable_pty::{native_pty_system, Child, CommandBuilder, PtyPair, PtySize};
14
15use crate::{
16 traits::{ConfigureCommand, ShellProcess, SpawnShell, SpawnedShell},
17 utils::is_recoverable_kill_error,
18};
19
/// Converts a boxed dynamic error into an [`io::Error`].
///
/// If the boxed error already is an `io::Error`, it is unboxed and returned as-is
/// (preserving its kind and message); otherwise it is wrapped via [`io::Error::other`].
fn into_io_error(err: Box<dyn StdError + Send + Sync>) -> io::Error {
    match err.downcast::<io::Error>() {
        Ok(io_err) => *io_err,
        Err(other_err) => io::Error::other(other_err),
    }
}
24
#[cfg_attr(docsrs, doc(cfg(feature = "portable-pty")))]
/// Command that spawns a shell inside a pseudo-terminal (PTY) using the `portable-pty` crate.
///
/// The default command is `sh` on Unix and `cmd` (with code page 65001) on Windows.
/// The PTY size defaults to 19 rows by 80 columns and can be changed with
/// `PtyCommand::with_size()`. Extra arguments are added with `PtyCommand::arg()`.
/// The working directory and environment variables are set via the `ConfigureCommand` trait.
#[derive(Debug, Clone)]
pub struct PtyCommand {
    // Program followed by its arguments; the first element is the executable passed to `new`.
    args: Vec<OsString>,
    // Environment variables applied to the spawned process via `CommandBuilder::env`.
    env: HashMap<OsString, OsString>,
    // Working directory for the spawned process; `None` leaves the builder's default untouched.
    current_dir: Option<PathBuf>,
    // Size of the PTY opened when the shell is spawned.
    pty_size: PtySize,
}
53
54#[cfg(unix)]
55impl Default for PtyCommand {
56 fn default() -> Self {
57 Self::new("sh")
58 }
59}
60
#[cfg(windows)]
impl Default for PtyCommand {
    /// Creates a command launching `cmd` with `/Q /K "echo off && chcp 65001"`.
    ///
    /// The `/K` command disables echo and switches the console code page to 65001 (UTF-8).
    fn default() -> Self {
        let mut command = PtyCommand::new("cmd");
        for extra_arg in ["/Q", "/K", "echo off && chcp 65001"] {
            command.arg(extra_arg);
        }
        command
    }
}
69
70impl PtyCommand {
71 pub fn new(command: impl Into<OsString>) -> Self {
75 Self {
76 args: vec![command.into()],
77 env: HashMap::new(),
78 current_dir: None,
79 pty_size: PtySize {
80 rows: 19,
81 cols: 80,
82 pixel_width: 0,
83 pixel_height: 0,
84 },
85 }
86 }
87
88 pub fn with_size(&mut self, rows: u16, cols: u16) -> &mut Self {
90 self.pty_size.rows = rows;
91 self.pty_size.cols = cols;
92 self
93 }
94
95 pub fn arg(&mut self, arg: impl Into<OsString>) -> &mut Self {
97 self.args.push(arg.into());
98 self
99 }
100
101 fn to_command_builder(&self) -> CommandBuilder {
102 let mut builder = CommandBuilder::from_argv(self.args.clone());
103 for (name, value) in &self.env {
104 builder.env(name, value);
105 }
106 if let Some(current_dir) = &self.current_dir {
107 builder.cwd(current_dir);
108 }
109 builder
110 }
111}
112
113impl ConfigureCommand for PtyCommand {
114 fn current_dir(&mut self, dir: &Path) {
115 self.current_dir = Some(dir.to_owned());
116 }
117
118 fn env(&mut self, name: &str, value: &OsStr) {
119 self.env
120 .insert(OsStr::new(name).to_owned(), value.to_owned());
121 }
122}
123
124impl SpawnShell for PtyCommand {
125 type ShellProcess = PtyShell;
126 type Reader = Box<dyn io::Read + Send>;
127 type Writer = Box<dyn io::Write + Send>;
128
129 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", err))]
130 fn spawn_shell(&mut self) -> io::Result<SpawnedShell<Self>> {
131 let pty_system = native_pty_system();
132 let PtyPair { master, slave } = pty_system
133 .openpty(self.pty_size)
134 .map_err(|err| into_io_error(err.into()))?;
135 #[cfg(feature = "tracing")]
136 tracing::debug!("created PTY pair");
137
138 let child = slave
139 .spawn_command(self.to_command_builder())
140 .map_err(|err| into_io_error(err.into()))?;
141 #[cfg(feature = "tracing")]
142 tracing::debug!("spawned command into PTY");
143
144 let reader = master
145 .try_clone_reader()
146 .map_err(|err| into_io_error(err.into()))?;
147 let writer = master
148 .take_writer()
149 .map_err(|err| into_io_error(err.into()))?;
150 Ok(SpawnedShell {
151 shell: PtyShell { child },
152 reader,
153 writer,
154 })
155 }
156}
157
#[cfg_attr(docsrs, doc(cfg(feature = "portable-pty")))]
/// Shell process running inside a PTY, as returned by spawning a `PtyCommand`.
#[derive(Debug)]
pub struct PtyShell {
    // Handle to the spawned child; polled with `try_wait()` and stopped with `kill()`.
    child: Box<dyn Child + Send + Sync>,
}
164
165impl ShellProcess for PtyShell {
166 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", err))]
167 fn check_is_alive(&mut self) -> io::Result<()> {
168 if let Some(exit_status) = self.child.try_wait()? {
169 let status_str = if exit_status.success() {
170 "zero"
171 } else {
172 "non-zero"
173 };
174 let message =
175 format!("Shell process has prematurely exited with {status_str} exit status");
176 Err(io::Error::new(io::ErrorKind::BrokenPipe, message))
177 } else {
178 Ok(())
179 }
180 }
181
182 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", err))]
183 fn terminate(mut self) -> io::Result<()> {
184 if self.child.try_wait()?.is_none() {
185 self.child.kill().or_else(|err| {
186 if is_recoverable_kill_error(&err) {
187 Ok(())
189 } else {
190 Err(err)
191 }
192 })?;
193 }
194 Ok(())
195 }
196
197 fn is_echoing(&self) -> bool {
198 true
199 }
200}
201
#[cfg(test)]
mod tests {
    use std::{
        io::{Read, Write},
        thread,
        time::Duration,
    };

    use super::*;
    use crate::{ShellOptions, Transcript, UserInput};

    /// Drives the raw `SpawnShell` / `ShellProcess` implementations directly, without `Transcript`.
    #[test]
    fn pty_trait_implementation() -> anyhow::Result<()> {
        let mut pty_command = PtyCommand::default();
        let mut spawned = pty_command.spawn_shell()?;

        thread::sleep(Duration::from_millis(100));
        spawned.shell.check_is_alive()?;

        writeln!(spawned.writer, "echo Hello")?;
        thread::sleep(Duration::from_millis(100));
        spawned.shell.check_is_alive()?;

        // Release our end of the PTY input (presumably lets the shell observe end of input;
        // confirm against portable-pty's writer `Drop`), then give the shell a moment
        // before terminating it.
        drop(spawned.writer);
        thread::sleep(Duration::from_millis(100));

        spawned.shell.terminate()?;
        let mut buffer = String::new();
        spawned.reader.read_to_string(&mut buffer)?;

        assert!(buffer.contains("Hello"), "Unexpected buffer: {buffer:?}");
        Ok(())
    }

    /// Captures a transcript through the PTY, including output written to stderr.
    #[test]
    fn creating_transcript_with_pty() -> anyhow::Result<()> {
        let mut options = ShellOptions::new(PtyCommand::default());
        let inputs = vec![
            UserInput::command("echo hello"),
            UserInput::command("echo foo && echo bar >&2"),
        ];
        let transcript = Transcript::from_inputs(&mut options, inputs)?;

        assert_eq!(transcript.interactions().len(), 2);

        {
            let interaction = &transcript.interactions()[0];
            assert_eq!(interaction.input().text, "echo hello");
            let output = interaction.output().as_ref();
            assert_eq!(output.trim(), "hello");
        }

        // `bar` goes to stderr but is still expected in the captured output, after `foo`.
        let interaction = &transcript.interactions()[1];
        assert_eq!(interaction.input().text, "echo foo && echo bar >&2");
        let output = interaction.output().as_ref();
        assert_eq!(
            output.split_whitespace().collect::<Vec<_>>(),
            ["foo", "bar"]
        );
        Ok(())
    }

    /// A backslash-newline continuation in the input must yield a single interaction.
    #[cfg(unix)]
    #[test]
    fn pty_transcript_with_multiline_input() -> anyhow::Result<()> {
        let mut options = ShellOptions::new(PtyCommand::default());
        let inputs = vec![UserInput::command("echo \\\nhello")];
        let transcript = Transcript::from_inputs(&mut options, inputs)?;

        assert_eq!(transcript.interactions().len(), 1);
        let interaction = &transcript.interactions()[0];
        let output = interaction.output().as_ref();
        assert_eq!(output.trim(), "hello");
        Ok(())
    }
}