shadow_terminal/
pty.rs

//! Creates a PTY in an OS subprocess and sends and receives bytes to/from it over channels.
//!
//! It doesn't actually maintain a visual representation; that requires the [`Wezterm`] terminal
//! to parse the PTY's output, see: [`ShadowTerminal`].
5
6use std::{ffi::OsString, io::Read as _};
7
8use snafu::{OptionExt as _, ResultExt as _};
9use tokio::sync::mpsc;
10use tracing::Instrument as _;
11
/// A single payload from the PTY output stream. This value comes from the `portable_pty` crate and
/// doesn't seem adjustable.
///
/// The reader loop in this module zero-fills the buffer before every read and always forwards the
/// whole array, so any bytes after the data that was actually read are `0`.
pub type BytesFromPTY = [u8; 4096];
/// A single payload from the user's input stream (or sometimes internal input).
///
/// When a payload is forwarded to the PTY, only the bytes before the first `0` byte are written;
/// a payload that contains no `0` byte is written in full.
pub type BytesFromSTDIN = [u8; 128];
17
18/// This is the PTY process that replaces the user's current TTY
19#[non_exhaustive]
20pub(crate) struct PTY {
21    /// PTY starting command
22    pub command: Vec<OsString>,
23    /// PTY width
24    pub width: u16,
25    /// PTY height
26    pub height: u16,
27    /// Send side of channel to send control messages like; shutdown and resize.
28    pub control_tx: tokio::sync::broadcast::Sender<crate::Protocol>,
29    /// Send side of channel sending updates from the PTY process
30    pub output_tx: tokio::sync::mpsc::Sender<crate::pty::BytesFromPTY>,
31}
32
33impl PTY {
34    /// Function just to isolate the PTY setup
35    fn setup_pty(&self) -> Result<portable_pty::PtyPair, crate::errors::PTYError> {
36        tracing::debug!("Setting up PTY");
37        let pty_system = portable_pty::native_pty_system();
38        let pair = pty_system
39            .openpty(Self::pty_size(self.width, self.height))
40            .with_whatever_context(|_| "Error opening PTY")?;
41
42        tracing::debug!("Launching `{:?}` on PTY", self.command);
43        let mut cmd = portable_pty::CommandBuilder::from_argv(self.command.clone());
44        cmd.cwd(
45            std::env::current_dir()
46                .with_whatever_context(|_| "Couldn't get user's current directory")?,
47        );
48        let spawn = pair
49            .slave
50            .spawn_command(cmd)
51            .with_whatever_context(|_| "Error spawning PTY command")?;
52        let killer = spawn.clone_killer();
53        Self::wait_for_pty_end(self.control_tx.clone(), spawn);
54        Self::kill_on_protocol_end(self.control_tx.subscribe(), killer);
55
56        tracing::trace!("Returning PTY pair");
57        Ok(pair)
58    }
59
60    /// The PTY crate is not async, so here we're basically just listening to the PTY to be able to
61    /// broadcast its output on an async channel.
62    fn pty_reader_loop(
63        pty_reader: std::boxed::Box<dyn std::io::Read + std::marker::Send>,
64        pty_reader_tx: mpsc::Sender<BytesFromPTY>,
65    ) -> tokio::task::JoinHandle<()> {
66        tokio::task::spawn_blocking(move || {
67            let mut reader = std::io::BufReader::new(pty_reader);
68            loop {
69                let mut buffer: BytesFromPTY = [0; 4096];
70
71                let now = std::time::Instant::now();
72                let read_result = reader.read(&mut buffer);
73                let elapsed = now.elapsed();
74
75                match read_result {
76                    Ok(0) => {
77                        tracing::debug!("PTY reader loop received 0 bytes, exiting...");
78                        break;
79                    }
80                    Ok(n) => {
81                        tracing::trace!(
82                            "Read {} PTY bytes. Time since last output {:?}",
83                            n,
84                            elapsed
85                        );
86                        let send_result = pty_reader_tx.blocking_send(buffer);
87                        if let Err(error) = send_result {
88                            tracing::error!("Broadcasting PTY output: {error:?}");
89                            break;
90                        }
91                    }
92                    Err(error) => tracing::error!("PTY reader: {error:?}"),
93                }
94            }
95            tracing::trace!("Leaving PTY reader loop");
96        })
97    }
98
99    /// A dedicated loop to listen for the official PTY end event.
100    fn wait_for_pty_end(
101        protocol_out: tokio::sync::broadcast::Sender<crate::Protocol>,
102        mut spawn: Box<dyn portable_pty::Child + Send + Sync>,
103    ) {
104        tokio::task::spawn_blocking(move || {
105            tracing::debug!("Starting to wait for PTY end");
106            let waiter_result = spawn.wait();
107            if let Err(error) = waiter_result {
108                tracing::error!("Waiting for PTY: {error:?}");
109            }
110
111            // A crude hack to make sure that early-exiting commands still have a chance to
112            // successfully send their output.
113            std::thread::sleep(std::time::Duration::from_millis(10));
114
115            let sender_result = protocol_out.send(crate::Protocol::End);
116            if let Err(error) = sender_result {
117                tracing::error!("Sending `Protocol::End` after: {error:?} ");
118            }
119            tracing::info!("PTY ended by its own accord");
120        });
121    }
122
123    /// Listen for the `End` message from the Tattoy protocol channel and then kill the PTY.
124    fn kill_on_protocol_end(
125        mut protocol_in: tokio::sync::broadcast::Receiver<crate::Protocol>,
126        mut spawn: Box<dyn portable_pty::ChildKiller + Send + Sync>,
127    ) {
128        let current_span = tracing::Span::current();
129        tokio::spawn(
130            async move {
131                tracing::debug!("Starting loop for PTY spawn to receive protocol messages");
132                loop {
133                    match protocol_in.recv().await {
134                        Ok(message) => {
135                            if matches!(message, crate::Protocol::End)  {
136                                tracing::debug!("PTY received Tattoy message {message:?}");
137                                let result = spawn.kill();
138                                if let Err(error) = result {
139                                    // This is the error when the PTY naturally ends. Is there a better way to
140                                    // match?
141                                    let pty_exit = "No such process";
142                                    if error.to_string().contains(pty_exit) {
143                                        tracing::debug!("Tried killing PTY that was already gone.");
144                                        break;
145                                    }
146
147                                    tracing::error!("Couldn't kill PTY: {error:?}");
148                                    // TODO: maybe we want to force exit here?
149                                }
150
151                                tracing::debug!(
152                                    "`kill()` (which includes OS kill signals) sent to PTY spawn process"
153                                );
154                                break;
155                            }
156                        }
157                        Err(error) => {
158                            tracing::error!("Reading protocol from PTY loop: {error:?}");
159                        }
160                    }
161                }
162                tracing::debug!("Leaving spawn shutdown listener loop.");
163            }
164            .instrument(current_span),
165        );
166    }
167
168    /// Start the PTY
169    pub(crate) async fn run(
170        self,
171        user_input_rx: mpsc::Receiver<BytesFromSTDIN>,
172        internal_input_rx: mpsc::Receiver<BytesFromSTDIN>,
173    ) -> Result<(), crate::errors::PTYError> {
174        let (pty_reader_tx, mut pty_reader_rx) = tokio::sync::mpsc::channel(1);
175
176        // It's important that we subscribe now, as that is what starts the backlog of protocol
177        // messages. It's possible that messages are sent during PTY startup and we don't want to
178        // miss any of those messages later when we finally start the listening loop.
179        let mut protocol_for_main_loop = self.control_tx.subscribe();
180
181        let pty_pair = self.setup_pty()?;
182        let pty_writer = pty_pair
183            .master
184            .take_writer()
185            .with_whatever_context(|err| format!("Getting PTY writer: {err:?}"))?;
186        let pty_reader = pty_pair
187            .master
188            .try_clone_reader()
189            .with_whatever_context(|err| format!("Getting PTY reader: {err:?}"))?;
190
191        Self::pty_reader_loop(pty_reader, pty_reader_tx);
192
193        // We have to drop the slave so that we don't hang on it when we exit.
194        drop(pty_pair.slave);
195
196        // TODO: should we be handling any errors in here?
197        let protocol_for_input_loop = self.control_tx.subscribe();
198        let current_span = tracing::Span::current();
199        tokio::spawn(async move {
200            let result = Self::forward_input(
201                user_input_rx,
202                internal_input_rx,
203                pty_writer,
204                pty_pair.master,
205                protocol_for_input_loop,
206            )
207            .instrument(current_span)
208            .await;
209            if let Err(err) = result {
210                tracing::error!("Writing to PTY stream: {err}");
211            }
212        });
213
214        tracing::debug!("Starting PTY reader loop");
215        #[expect(
216            clippy::integer_division_remainder_used,
217            reason = "`tokio::select!` generates this."
218        )]
219        loop {
220            tokio::select! {
221                result = self.read_stream(&mut pty_reader_rx) => {
222                    if let Err(error) = result {
223                        // TODO: The error should be bubbled, and logged centrally
224                        tracing::error!("{error:?}");
225                        snafu::whatever!("{error:?}");
226                    }
227                }
228                result = protocol_for_main_loop.recv() => {
229                    match result {
230                        Ok(message) => {
231                            if matches!(message, crate::Protocol::End) {
232                                break;
233                            }
234                        }
235                        Err(err) => {
236                            // TODO: The error should be bubbled, and logged centrally
237                            tracing::error!("{err:?}");
238                            snafu::whatever!("{err:?}");
239                        },
240
241                    }
242                }
243
244            }
245        }
246
247        tracing::debug!("PTY reader loop finished");
248        Ok(())
249    }
250
251    /// Read bytes from the underlying PTY sub process and forward them to the Shadow Terminal.
252    async fn read_stream(
253        &self,
254        pty_reader_rx: &mut mpsc::Receiver<BytesFromPTY>,
255    ) -> Result<(), crate::errors::PTYError> {
256        let Some(bytes) = pty_reader_rx.recv().await else {
257            return Ok(());
258        };
259
260        let result = self.output_tx.send(bytes).await;
261        if let Err(err) = result {
262            tracing::error!("Sending bytes on PTY output channel: {err}");
263        }
264
265        let output = String::from_utf8_lossy(&bytes)
266            .to_string()
267            .replace('\x1b', "^");
268        tracing::trace!("Sent PTY output, sample:\n{:.500}...", output);
269
270        Ok(())
271    }
272
273    /// Forward channel bytes from the user's input to the virtual PTY
274    async fn forward_input(
275        mut user_input: mpsc::Receiver<BytesFromSTDIN>,
276        mut internal_input: mpsc::Receiver<BytesFromSTDIN>,
277        mut pty_writer: std::boxed::Box<dyn std::io::Write + std::marker::Send>,
278        pty_master: std::boxed::Box<(dyn portable_pty::MasterPty + std::marker::Send + 'static)>,
279        mut protocol: tokio::sync::broadcast::Receiver<crate::Protocol>,
280    ) -> Result<(), crate::errors::PTYError> {
281        tracing::debug!("Starting `forward_input` loop");
282
283        #[expect(
284            clippy::integer_division_remainder_used,
285            reason = "This is generated by the `tokio::select!`"
286        )]
287        loop {
288            tokio::select! {
289                message = protocol.recv() => {
290                    Self::handle_protocol_message_for_input_loop(&message, &pty_master)?;
291                    if matches!(message, Ok(crate::Protocol::End)) {
292                        break;
293                    }
294                }
295                Some(some_bytes) = user_input.recv() => {
296                    Self::handle_input_bytes(some_bytes, &mut pty_writer)?;
297                }
298                Some(some_bytes) = internal_input.recv() => {
299                    Self::handle_input_bytes(some_bytes, &mut pty_writer)?;
300                }
301            }
302        }
303
304        tracing::debug!("`forward_input` loop finished");
305        Ok(())
306    }
307
308    /// Handle a message from the Tattoy protocol broadcast channel.
309    fn handle_protocol_message_for_input_loop(
310        message: &std::result::Result<crate::Protocol, tokio::sync::broadcast::error::RecvError>,
311        pty_master: &std::boxed::Box<(dyn portable_pty::MasterPty + std::marker::Send + 'static)>,
312    ) -> Result<(), crate::errors::PTYError> {
313        match message {
314            Ok(crate::Protocol::End) => {
315                tracing::trace!("PTY input forwarder task received {message:?}");
316                return Ok(());
317            }
318            Ok(crate::Protocol::Resize { width, height }) => {
319                tracing::debug!("Resize event received on PTY input loop {message:?}");
320
321                let result = pty_master.resize(Self::pty_size(*width, *height));
322                if result.is_err() {
323                    tracing::error!("Couldn't resize underlying PTY subprocesss: {result:?}");
324                }
325            }
326            Ok(_) => (),
327            Err(err) => snafu::whatever!("{err:?}"),
328        }
329
330        Ok(())
331    }
332
333    /// Handle input from end user.
334    fn handle_input_bytes(
335        bytes: BytesFromSTDIN,
336        pty_stdin: &mut std::boxed::Box<dyn std::io::Write + std::marker::Send>,
337    ) -> Result<(), crate::errors::PTYError> {
338        tracing::trace!(
339            "Forwarding input to PTY: '{}'",
340            String::from_utf8_lossy(&bytes).replace('\n', "\\n")
341        );
342
343        let maybe_size = bytes.iter().position(|byte| byte == &0);
344        let size = maybe_size.unwrap_or(128);
345        let byte_slice = bytes.get(0..size).with_whatever_context(|| {
346            "Couldn't get slice of input payload. Should be impossible."
347        })?;
348
349        pty_stdin
350            .write_all(byte_slice)
351            .with_whatever_context(|err| {
352                format!("`handle_input_bytes()`: couldn't write bytes into PTY's STDIN: {err:?}")
353            })?;
354        pty_stdin
355            .flush()
356            .with_whatever_context(|err| format!("Couldn't flush STDIN stream to PTY: {err:?}"))?;
357
358        Ok(())
359    }
360
361    /// Just a little central place to build the `PtySize` struct consistently.
362    const fn pty_size(width: u16, height: u16) -> portable_pty::PtySize {
363        portable_pty::PtySize {
364            cols: width,
365            rows: height,
366            // Not all systems support pixel_width, pixel_height,
367            // but it is good practice to set it to something
368            // that matches the size of the selected font.
369            pixel_width: 0,
370            pixel_height: 0,
371        }
372    }
373
374    /// Insert bytes into a buffer.
375    pub(crate) fn add_bytes_to_buffer(
376        buffer: &mut BytesFromSTDIN,
377        bytes: &[u8],
378    ) -> Result<(), crate::errors::PTYError> {
379        if bytes.len() > buffer.len() {
380            snafu::whatever!(
381                "Bytes ({}) to add to buffer are more than the buffer size ({}).",
382                bytes.len(),
383                buffer.len()
384            );
385        }
386        for (i, chunk_byte) in bytes.iter().enumerate() {
387            let buffer_byte = buffer
388                .get_mut(i)
389                .with_whatever_context(|| "Couldn't get byte from buffer")?;
390            *buffer_byte = *chunk_byte;
391        }
392
393        Ok(())
394    }
395}
396
397impl Drop for PTY {
398    fn drop(&mut self) {
399        tracing::debug!("PTY dropped, broadcasting `End` signal.");
400
401        let result: Result<_, crate::errors::PTYError> = self
402            .control_tx
403            .send(crate::Protocol::End)
404            .with_whatever_context(|err| {
405                format!("Couldn't send shutdown signal after PTY finished: {err:?}")
406            });
407
408        if let Err(err) = result {
409            tracing::error!("{err:?}");
410        }
411    }
412}
413
414#[cfg(test)]
415#[expect(clippy::print_stderr, reason = "Tests aren't so strict")]
416mod test {
417    use super::*;
418
419    fn run(
420        command: Vec<OsString>,
421    ) -> (
422        tokio::task::JoinHandle<std::string::String>,
423        mpsc::Sender<BytesFromSTDIN>,
424    ) {
425        // TODO: Think about a convenient way to enable this whenever only a single test is ran
426        // setup_logging().unwrap();
427
428        let (pty_output_tx, mut pty_output_rx) = mpsc::channel::<BytesFromPTY>(8);
429        let (pty_input_tx, pty_input_rx) = mpsc::channel::<BytesFromSTDIN>(1);
430        let (_, internal_input_rx) = mpsc::channel::<BytesFromSTDIN>(8);
431        let (protocol_tx, _) = tokio::sync::broadcast::channel(16);
432
433        let output_task = tokio::spawn(async move {
434            tracing::debug!("TEST: Output listener loop starting...");
435            let mut result: Vec<u8> = vec![];
436
437            // TODO: don't just rely on test commands sending an `exit` to allow this loop to
438            // finish.
439            while let Some(bytes) = pty_output_rx.recv().await {
440                result.extend(bytes.iter().copied());
441            }
442
443            let output = String::from_utf8_lossy(&result).into_owned();
444            tracing::debug!("TEST: `interactive()` output: {output:?}");
445            output
446        });
447
448        tokio::spawn(async move {
449            tracing::debug!("TEST: PTY.run() starting...");
450            let pty = PTY {
451                command,
452                width: 10,
453                height: 10,
454                output_tx: pty_output_tx,
455                control_tx: protocol_tx.clone(),
456            };
457            let result = pty.run(pty_input_rx, internal_input_rx).await;
458            if let Err(err) = result {
459                tracing::warn!("PTY (for tests) handle: {err:?}");
460            }
461            tracing::debug!("Test PTY.run() done");
462        });
463
464        tracing::debug!("TEST: Leaving run helper...");
465        (output_task, pty_input_tx)
466    }
467
468    /// TODO: Powershell isn't displaying emoji: 🌍
469    fn cat_earth_command() -> String {
470        let cat_command = "cat";
471        let path = crate::tests::helpers::workspace_dir()
472            .join("shadow-terminal")
473            .join("src")
474            .join("tests")
475            .join("cat_me.txt");
476
477        #[cfg(not(target_os = "windows"))]
478        let sleep = "&& sleep 0.5";
479        #[cfg(target_os = "windows")]
480        let sleep = "; Start-Sleep -Milliseconds 5";
481
482        format!("{cat_command} {} {sleep}", path.display())
483    }
484
485    fn stdin_bytes(input: &str) -> BytesFromSTDIN {
486        let mut buffer: BytesFromSTDIN = [0; 128];
487        #[expect(
488            clippy::indexing_slicing,
489            reason = "How do I do a range slice with []?"
490        )]
491        buffer[..input.len()].copy_from_slice(input.as_bytes());
492        buffer
493    }
494
495    #[tokio::test(flavor = "multi_thread")]
496    async fn basic_output() {
497        let mut command = crate::tests::helpers::get_canonical_shell();
498
499        #[cfg(not(target_os = "windows"))]
500        command.push("-c".into());
501        #[cfg(target_os = "windows")]
502        command.push("-Command".into());
503
504        command.push(cat_earth_command().into());
505
506        let (output_task, _) = run(command);
507        let result = output_task.await.unwrap();
508        eprintln!("{result}");
509
510        assert!(result.contains("earth"));
511    }
512
513    #[cfg(not(target_os = "windows"))]
514    #[tokio::test(flavor = "multi_thread")]
515    async fn interactive() {
516        let (output_task, input_channel) = run(crate::tests::helpers::get_canonical_shell());
517        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
518
519        #[cfg(not(target_os = "windows"))]
520        let exit = "&& exit";
521        #[cfg(target_os = "windows")]
522        let exit = "; exit";
523        let command = format!("{} {exit}\n", cat_earth_command());
524
525        input_channel
526            .send(stdin_bytes(command.as_ref()))
527            .await
528            .unwrap();
529        tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
530        let result = output_task.await.unwrap();
531        eprintln!("{result}");
532
533        assert!(result.contains("earth"));
534    }
535}