1use std::{ffi::OsString, io::Read as _};
7
8use snafu::{OptionExt as _, ResultExt as _};
9use tokio::sync::mpsc;
10use tracing::Instrument as _;
11
/// A fixed-size chunk of raw output read from the PTY.
///
/// The reader fills a zeroed buffer on each read, so only the leading bytes that were
/// actually read are meaningful; the rest is `0` padding.
pub type BytesFromPTY = [u8; 4 * 1024];

/// A fixed-size buffer of input bytes destined for the PTY's STDIN.
///
/// The payload ends at the first `0` byte, or fills the whole buffer if there is none.
pub type BytesFromSTDIN = [u8; 128];
17
/// A command running inside a pseudo-terminal (PTY).
///
/// Build one, then call `run()`, which consumes it. When it is dropped it broadcasts
/// `Protocol::End` on `control_tx`, so that every listener shuts down.
#[non_exhaustive]
pub(crate) struct PTY {
    /// The command to launch on the PTY, as an argv: the program followed by its arguments.
    pub command: Vec<OsString>,
    /// Width of the PTY, in columns.
    pub width: u16,
    /// Height of the PTY, in rows.
    pub height: u16,
    /// Broadcast channel for `crate::Protocol` control messages. The PTY both sends on it
    /// (e.g. `End`) and subscribes to it (e.g. to react to `End` and `Resize`).
    pub control_tx: tokio::sync::broadcast::Sender<crate::Protocol>,
    /// Channel on which chunks of output read from the PTY are sent.
    pub output_tx: tokio::sync::mpsc::Sender<crate::pty::BytesFromPTY>,
}
32
33impl PTY {
34 fn setup_pty(&self) -> Result<portable_pty::PtyPair, crate::errors::PTYError> {
36 tracing::debug!("Setting up PTY");
37 let pty_system = portable_pty::native_pty_system();
38 let pair = pty_system
39 .openpty(Self::pty_size(self.width, self.height))
40 .with_whatever_context(|_| "Error opening PTY")?;
41
42 tracing::debug!("Launching `{:?}` on PTY", self.command);
43 let mut cmd = portable_pty::CommandBuilder::from_argv(self.command.clone());
44 cmd.cwd(
45 std::env::current_dir()
46 .with_whatever_context(|_| "Couldn't get user's current directory")?,
47 );
48 let spawn = pair
49 .slave
50 .spawn_command(cmd)
51 .with_whatever_context(|_| "Error spawning PTY command")?;
52 let killer = spawn.clone_killer();
53 Self::wait_for_pty_end(self.control_tx.clone(), spawn);
54 Self::kill_on_protocol_end(self.control_tx.subscribe(), killer);
55
56 tracing::trace!("Returning PTY pair");
57 Ok(pair)
58 }
59
60 fn pty_reader_loop(
63 pty_reader: std::boxed::Box<dyn std::io::Read + std::marker::Send>,
64 pty_reader_tx: mpsc::Sender<BytesFromPTY>,
65 ) -> tokio::task::JoinHandle<()> {
66 tokio::task::spawn_blocking(move || {
67 let mut reader = std::io::BufReader::new(pty_reader);
68 loop {
69 let mut buffer: BytesFromPTY = [0; 4096];
70
71 let now = std::time::Instant::now();
72 let read_result = reader.read(&mut buffer);
73 let elapsed = now.elapsed();
74
75 match read_result {
76 Ok(0) => {
77 tracing::debug!("PTY reader loop received 0 bytes, exiting...");
78 break;
79 }
80 Ok(n) => {
81 tracing::trace!(
82 "Read {} PTY bytes. Time since last output {:?}",
83 n,
84 elapsed
85 );
86 let send_result = pty_reader_tx.blocking_send(buffer);
87 if let Err(error) = send_result {
88 tracing::error!("Broadcasting PTY output: {error:?}");
89 break;
90 }
91 }
92 Err(error) => tracing::error!("PTY reader: {error:?}"),
93 }
94 }
95 tracing::trace!("Leaving PTY reader loop");
96 })
97 }
98
99 fn wait_for_pty_end(
101 protocol_out: tokio::sync::broadcast::Sender<crate::Protocol>,
102 mut spawn: Box<dyn portable_pty::Child + Send + Sync>,
103 ) {
104 tokio::task::spawn_blocking(move || {
105 tracing::debug!("Starting to wait for PTY end");
106 let waiter_result = spawn.wait();
107 if let Err(error) = waiter_result {
108 tracing::error!("Waiting for PTY: {error:?}");
109 }
110
111 std::thread::sleep(std::time::Duration::from_millis(10));
114
115 let sender_result = protocol_out.send(crate::Protocol::End);
116 if let Err(error) = sender_result {
117 tracing::error!("Sending `Protocol::End` after: {error:?} ");
118 }
119 tracing::info!("PTY ended by its own accord");
120 });
121 }
122
123 fn kill_on_protocol_end(
125 mut protocol_in: tokio::sync::broadcast::Receiver<crate::Protocol>,
126 mut spawn: Box<dyn portable_pty::ChildKiller + Send + Sync>,
127 ) {
128 let current_span = tracing::Span::current();
129 tokio::spawn(
130 async move {
131 tracing::debug!("Starting loop for PTY spawn to receive protocol messages");
132 loop {
133 match protocol_in.recv().await {
134 Ok(message) => {
135 if matches!(message, crate::Protocol::End) {
136 tracing::debug!("PTY received Tattoy message {message:?}");
137 let result = spawn.kill();
138 if let Err(error) = result {
139 let pty_exit = "No such process";
142 if error.to_string().contains(pty_exit) {
143 tracing::debug!("Tried killing PTY that was already gone.");
144 break;
145 }
146
147 tracing::error!("Couldn't kill PTY: {error:?}");
148 }
150
151 tracing::debug!(
152 "`kill()` (which includes OS kill signals) sent to PTY spawn process"
153 );
154 break;
155 }
156 }
157 Err(error) => {
158 tracing::error!("Reading protocol from PTY loop: {error:?}");
159 }
160 }
161 }
162 tracing::debug!("Leaving spawn shutdown listener loop.");
163 }
164 .instrument(current_span),
165 );
166 }
167
168 pub(crate) async fn run(
170 self,
171 user_input_rx: mpsc::Receiver<BytesFromSTDIN>,
172 internal_input_rx: mpsc::Receiver<BytesFromSTDIN>,
173 ) -> Result<(), crate::errors::PTYError> {
174 let (pty_reader_tx, mut pty_reader_rx) = tokio::sync::mpsc::channel(1);
175
176 let mut protocol_for_main_loop = self.control_tx.subscribe();
180
181 let pty_pair = self.setup_pty()?;
182 let pty_writer = pty_pair
183 .master
184 .take_writer()
185 .with_whatever_context(|err| format!("Getting PTY writer: {err:?}"))?;
186 let pty_reader = pty_pair
187 .master
188 .try_clone_reader()
189 .with_whatever_context(|err| format!("Getting PTY reader: {err:?}"))?;
190
191 Self::pty_reader_loop(pty_reader, pty_reader_tx);
192
193 drop(pty_pair.slave);
195
196 let protocol_for_input_loop = self.control_tx.subscribe();
198 let current_span = tracing::Span::current();
199 tokio::spawn(async move {
200 let result = Self::forward_input(
201 user_input_rx,
202 internal_input_rx,
203 pty_writer,
204 pty_pair.master,
205 protocol_for_input_loop,
206 )
207 .instrument(current_span)
208 .await;
209 if let Err(err) = result {
210 tracing::error!("Writing to PTY stream: {err}");
211 }
212 });
213
214 tracing::debug!("Starting PTY reader loop");
215 #[expect(
216 clippy::integer_division_remainder_used,
217 reason = "`tokio::select!` generates this."
218 )]
219 loop {
220 tokio::select! {
221 result = self.read_stream(&mut pty_reader_rx) => {
222 if let Err(error) = result {
223 tracing::error!("{error:?}");
225 snafu::whatever!("{error:?}");
226 }
227 }
228 result = protocol_for_main_loop.recv() => {
229 match result {
230 Ok(message) => {
231 if matches!(message, crate::Protocol::End) {
232 break;
233 }
234 }
235 Err(err) => {
236 tracing::error!("{err:?}");
238 snafu::whatever!("{err:?}");
239 },
240
241 }
242 }
243
244 }
245 }
246
247 tracing::debug!("PTY reader loop finished");
248 Ok(())
249 }
250
251 async fn read_stream(
253 &self,
254 pty_reader_rx: &mut mpsc::Receiver<BytesFromPTY>,
255 ) -> Result<(), crate::errors::PTYError> {
256 let Some(bytes) = pty_reader_rx.recv().await else {
257 return Ok(());
258 };
259
260 let result = self.output_tx.send(bytes).await;
261 if let Err(err) = result {
262 tracing::error!("Sending bytes on PTY output channel: {err}");
263 }
264
265 let output = String::from_utf8_lossy(&bytes)
266 .to_string()
267 .replace('\x1b', "^");
268 tracing::trace!("Sent PTY output, sample:\n{:.500}...", output);
269
270 Ok(())
271 }
272
273 async fn forward_input(
275 mut user_input: mpsc::Receiver<BytesFromSTDIN>,
276 mut internal_input: mpsc::Receiver<BytesFromSTDIN>,
277 mut pty_writer: std::boxed::Box<dyn std::io::Write + std::marker::Send>,
278 pty_master: std::boxed::Box<(dyn portable_pty::MasterPty + std::marker::Send + 'static)>,
279 mut protocol: tokio::sync::broadcast::Receiver<crate::Protocol>,
280 ) -> Result<(), crate::errors::PTYError> {
281 tracing::debug!("Starting `forward_input` loop");
282
283 #[expect(
284 clippy::integer_division_remainder_used,
285 reason = "This is generated by the `tokio::select!`"
286 )]
287 loop {
288 tokio::select! {
289 message = protocol.recv() => {
290 Self::handle_protocol_message_for_input_loop(&message, &pty_master)?;
291 if matches!(message, Ok(crate::Protocol::End)) {
292 break;
293 }
294 }
295 Some(some_bytes) = user_input.recv() => {
296 Self::handle_input_bytes(some_bytes, &mut pty_writer)?;
297 }
298 Some(some_bytes) = internal_input.recv() => {
299 Self::handle_input_bytes(some_bytes, &mut pty_writer)?;
300 }
301 }
302 }
303
304 tracing::debug!("`forward_input` loop finished");
305 Ok(())
306 }
307
308 fn handle_protocol_message_for_input_loop(
310 message: &std::result::Result<crate::Protocol, tokio::sync::broadcast::error::RecvError>,
311 pty_master: &std::boxed::Box<(dyn portable_pty::MasterPty + std::marker::Send + 'static)>,
312 ) -> Result<(), crate::errors::PTYError> {
313 match message {
314 Ok(crate::Protocol::End) => {
315 tracing::trace!("PTY input forwarder task received {message:?}");
316 return Ok(());
317 }
318 Ok(crate::Protocol::Resize { width, height }) => {
319 tracing::debug!("Resize event received on PTY input loop {message:?}");
320
321 let result = pty_master.resize(Self::pty_size(*width, *height));
322 if result.is_err() {
323 tracing::error!("Couldn't resize underlying PTY subprocesss: {result:?}");
324 }
325 }
326 Ok(_) => (),
327 Err(err) => snafu::whatever!("{err:?}"),
328 }
329
330 Ok(())
331 }
332
333 fn handle_input_bytes(
335 bytes: BytesFromSTDIN,
336 pty_stdin: &mut std::boxed::Box<dyn std::io::Write + std::marker::Send>,
337 ) -> Result<(), crate::errors::PTYError> {
338 tracing::trace!(
339 "Forwarding input to PTY: '{}'",
340 String::from_utf8_lossy(&bytes).replace('\n', "\\n")
341 );
342
343 let maybe_size = bytes.iter().position(|byte| byte == &0);
344 let size = maybe_size.unwrap_or(128);
345 let byte_slice = bytes.get(0..size).with_whatever_context(|| {
346 "Couldn't get slice of input payload. Should be impossible."
347 })?;
348
349 pty_stdin
350 .write_all(byte_slice)
351 .with_whatever_context(|err| {
352 format!("`handle_input_bytes()`: couldn't write bytes into PTY's STDIN: {err:?}")
353 })?;
354 pty_stdin
355 .flush()
356 .with_whatever_context(|err| format!("Couldn't flush STDIN stream to PTY: {err:?}"))?;
357
358 Ok(())
359 }
360
361 const fn pty_size(width: u16, height: u16) -> portable_pty::PtySize {
363 portable_pty::PtySize {
364 cols: width,
365 rows: height,
366 pixel_width: 0,
370 pixel_height: 0,
371 }
372 }
373
374 pub(crate) fn add_bytes_to_buffer(
376 buffer: &mut BytesFromSTDIN,
377 bytes: &[u8],
378 ) -> Result<(), crate::errors::PTYError> {
379 if bytes.len() > buffer.len() {
380 snafu::whatever!(
381 "Bytes ({}) to add to buffer are more than the buffer size ({}).",
382 bytes.len(),
383 buffer.len()
384 );
385 }
386 for (i, chunk_byte) in bytes.iter().enumerate() {
387 let buffer_byte = buffer
388 .get_mut(i)
389 .with_whatever_context(|| "Couldn't get byte from buffer")?;
390 *buffer_byte = *chunk_byte;
391 }
392
393 Ok(())
394 }
395}
396
397impl Drop for PTY {
398 fn drop(&mut self) {
399 tracing::debug!("PTY dropped, broadcasting `End` signal.");
400
401 let result: Result<_, crate::errors::PTYError> = self
402 .control_tx
403 .send(crate::Protocol::End)
404 .with_whatever_context(|err| {
405 format!("Couldn't send shutdown signal after PTY finished: {err:?}")
406 });
407
408 if let Err(err) = result {
409 tracing::error!("{err:?}");
410 }
411 }
412}
413
// Integration-style tests that run real shell commands inside a `PTY`.
#[cfg(test)]
#[expect(clippy::print_stderr, reason = "Tests aren't so strict")]
mod test {
    use super::*;

    /// Test helper: spawns a 10x10 `PTY` running `command`.
    ///
    /// Returns a handle to a task that collects all of the PTY's output as a string, and the
    /// sender for user input. The output task finishes once every sender of the output
    /// channel is gone, i.e. once the `PTY` has been dropped.
    fn run(
        command: Vec<OsString>,
    ) -> (
        tokio::task::JoinHandle<std::string::String>,
        mpsc::Sender<BytesFromSTDIN>,
    ) {
        let (pty_output_tx, mut pty_output_rx) = mpsc::channel::<BytesFromPTY>(8);
        let (pty_input_tx, pty_input_rx) = mpsc::channel::<BytesFromSTDIN>(1);
        // The sender is dropped immediately, so there is never any internal input.
        let (_, internal_input_rx) = mpsc::channel::<BytesFromSTDIN>(8);
        let (protocol_tx, _) = tokio::sync::broadcast::channel(16);

        let output_task = tokio::spawn(async move {
            tracing::debug!("TEST: Output listener loop starting...");
            let mut result: Vec<u8> = vec![];

            // Chunks are zero-padded, so `result` contains NUL bytes between pieces of real
            // output. That doesn't affect the `contains()` assertions in the tests.
            while let Some(bytes) = pty_output_rx.recv().await {
                result.extend(bytes.iter().copied());
            }

            let output = String::from_utf8_lossy(&result).into_owned();
            tracing::debug!("TEST: `interactive()` output: {output:?}");
            output
        });

        tokio::spawn(async move {
            tracing::debug!("TEST: PTY.run() starting...");
            let pty = PTY {
                command,
                width: 10,
                height: 10,
                output_tx: pty_output_tx,
                control_tx: protocol_tx.clone(),
            };
            let result = pty.run(pty_input_rx, internal_input_rx).await;
            if let Err(err) = result {
                tracing::warn!("PTY (for tests) handle: {err:?}");
            }
            tracing::debug!("Test PTY.run() done");
        });

        tracing::debug!("TEST: Leaving run helper...");
        (output_task, pty_input_tx)
    }

    /// Builds a shell command that `cat`s the `cat_me.txt` fixture and then sleeps briefly.
    ///
    /// NOTE(review): the sleep is presumably there so that the output is read before the
    /// shell exits. The windows variant sleeps 5 ms versus 0.5 s elsewhere — confirm that's
    /// intended.
    fn cat_earth_command() -> String {
        let cat_command = "cat";
        let path = crate::tests::helpers::workspace_dir()
            .join("shadow-terminal")
            .join("src")
            .join("tests")
            .join("cat_me.txt");

        #[cfg(not(target_os = "windows"))]
        let sleep = "&& sleep 0.5";
        #[cfg(target_os = "windows")]
        let sleep = "; Start-Sleep -Milliseconds 5";

        format!("{cat_command} {} {sleep}", path.display())
    }

    /// Packs `input` into a zero-padded `BytesFromSTDIN` buffer.
    ///
    /// Panics if `input` is longer than the buffer (128 bytes).
    fn stdin_bytes(input: &str) -> BytesFromSTDIN {
        let mut buffer: BytesFromSTDIN = [0; 128];
        #[expect(
            clippy::indexing_slicing,
            reason = "How do I do a range slice with []?"
        )]
        buffer[..input.len()].copy_from_slice(input.as_bytes());
        buffer
    }

    /// Runs the cat command non-interactively, via the shell's `-c`/`-Command` flag, and
    /// checks that the fixture's contents appear in the PTY output.
    #[tokio::test(flavor = "multi_thread")]
    async fn basic_output() {
        let mut command = crate::tests::helpers::get_canonical_shell();

        #[cfg(not(target_os = "windows"))]
        command.push("-c".into());
        #[cfg(target_os = "windows")]
        command.push("-Command".into());

        command.push(cat_earth_command().into());

        let (output_task, _) = run(command);
        let result = output_task.await.unwrap();
        eprintln!("{result}");

        assert!(result.contains("earth"));
    }

    /// Starts an interactive shell, types the cat command (followed by `exit`) into it, and
    /// checks that the fixture's contents appear in the PTY output.
    ///
    /// NOTE(review): this test is compiled only on non-windows targets, so the
    /// `target_os = "windows"` branch for `exit` below is dead code. The 5 ms sleep before
    /// awaiting `output_task` also looks redundant, since the await already waits for the
    /// output channel to close.
    #[cfg(not(target_os = "windows"))]
    #[tokio::test(flavor = "multi_thread")]
    async fn interactive() {
        let (output_task, input_channel) = run(crate::tests::helpers::get_canonical_shell());
        // Give the shell time to start before typing into it.
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        #[cfg(not(target_os = "windows"))]
        let exit = "&& exit";
        #[cfg(target_os = "windows")]
        let exit = "; exit";
        let command = format!("{} {exit}\n", cat_earth_command());

        input_channel
            .send(stdin_bytes(command.as_ref()))
            .await
            .unwrap();
        tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
        let result = output_task.await.unwrap();
        eprintln!("{result}");

        assert!(result.contains("earth"));
    }
}