Skip to main content

rust_expect/backend/
pty.rs

1//! PTY backend for local process spawning.
2//!
3//! This module provides the PTY backend that uses the rust-pty crate
4//! to spawn local processes with pseudo-terminal support.
5
6use std::io;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11
12use crate::config::SessionConfig;
13use crate::error::{ExpectError, Result, SpawnError};
14
15/// A PTY-based transport for local process communication.
16pub struct PtyTransport {
17    /// The PTY reader half.
18    reader: Box<dyn AsyncRead + Unpin + Send>,
19    /// The PTY writer half.
20    writer: Box<dyn AsyncWrite + Unpin + Send>,
21    /// Process ID.
22    pid: Option<u32>,
23}
24
25impl PtyTransport {
26    /// Create a new PTY transport from reader and writer.
27    pub fn new<R, W>(reader: R, writer: W) -> Self
28    where
29        R: AsyncRead + Unpin + Send + 'static,
30        W: AsyncWrite + Unpin + Send + 'static,
31    {
32        Self {
33            reader: Box::new(reader),
34            writer: Box::new(writer),
35            pid: None,
36        }
37    }
38
39    /// Set the process ID.
40    pub const fn set_pid(&mut self, pid: u32) {
41        self.pid = Some(pid);
42    }
43
44    /// Get the process ID.
45    #[must_use]
46    pub const fn pid(&self) -> Option<u32> {
47        self.pid
48    }
49}
50
51impl AsyncRead for PtyTransport {
52    fn poll_read(
53        mut self: Pin<&mut Self>,
54        cx: &mut Context<'_>,
55        buf: &mut ReadBuf<'_>,
56    ) -> Poll<io::Result<()>> {
57        Pin::new(&mut self.reader).poll_read(cx, buf)
58    }
59}
60
61impl AsyncWrite for PtyTransport {
62    fn poll_write(
63        mut self: Pin<&mut Self>,
64        cx: &mut Context<'_>,
65        buf: &[u8],
66    ) -> Poll<io::Result<usize>> {
67        Pin::new(&mut self.writer).poll_write(cx, buf)
68    }
69
70    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
71        Pin::new(&mut self.writer).poll_flush(cx)
72    }
73
74    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
75        Pin::new(&mut self.writer).poll_shutdown(cx)
76    }
77}
78
79/// Configuration for PTY spawning.
80#[derive(Debug, Clone)]
81pub struct PtyConfig {
82    /// Terminal dimensions (cols, rows).
83    pub dimensions: (u16, u16),
84    /// Whether to use a login shell.
85    pub login_shell: bool,
86    /// Environment variable handling.
87    pub env_mode: EnvMode,
88    /// Environment variables to apply per `env_mode` (overlay for `Extend`,
89    /// the full set for `Clear`, ignored for `Inherit`).
90    pub env: std::collections::HashMap<String, String>,
91}
92
93impl Default for PtyConfig {
94    fn default() -> Self {
95        Self {
96            dimensions: (80, 24),
97            login_shell: false,
98            env_mode: EnvMode::Inherit,
99            env: std::collections::HashMap::new(),
100        }
101    }
102}
103
104impl From<&SessionConfig> for PtyConfig {
105    fn from(config: &SessionConfig) -> Self {
106        Self {
107            dimensions: config.dimensions,
108            login_shell: false,
109            env_mode: if config.env.is_empty() {
110                EnvMode::Inherit
111            } else {
112                EnvMode::Extend
113            },
114            env: config.env.clone(),
115        }
116    }
117}
118
/// How the spawned child's environment is derived.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnvMode {
    /// Start from the parent's environment. Configured variables, if any,
    /// are still applied on top (same effect as `Extend`).
    Inherit,
    /// Start from an empty environment; only the configured variables exist.
    Clear,
    /// Start from the parent's environment and add/overwrite the configured
    /// variables.
    Extend,
}
129
130/// Apply `env_mode` plus the user-supplied overrides to the calling
131/// process's environment.
132///
133/// **Must only be called in a child process after `fork`** — it mutates
134/// global `environ` state via `setenv`/`clearenv`/`unsetenv`, which is
135/// safe only because the child is single-threaded at this point (between
136/// fork and exec).
137///
138/// - `Inherit`: leave the inherited parent env in place; just apply overrides.
139/// - `Clear`:   wipe environ (Linux: `clearenv`; elsewhere: walk + `unsetenv`)
140///   then apply overrides.
141/// - `Extend`:  same as Inherit semantically; overrides overwrite existing.
142#[cfg(unix)]
143#[allow(unsafe_code)]
144unsafe fn apply_env_in_child(
145    env_mode: EnvMode,
146    env_pairs: &[(std::ffi::CString, std::ffi::CString)],
147) {
148    // SAFETY: caller (this function's doc-comment contract) guarantees we are
149    // executing post-fork, pre-exec in a child process, which is single-threaded.
150    // Mutating `environ` via clearenv/setenv/unsetenv is therefore race-free.
151    unsafe {
152        match env_mode {
153            EnvMode::Inherit | EnvMode::Extend => {}
154            EnvMode::Clear => {
155                #[cfg(target_os = "linux")]
156                {
157                    libc::clearenv();
158                }
159                #[cfg(not(target_os = "linux"))]
160                {
161                    // Collect every existing key into owned CStrings BEFORE we
162                    // start calling unsetenv. unsetenv mutates the global
163                    // `environ` array — entries shift, the array can be
164                    // reallocated — so iterating it concurrently with
165                    // mutation is fragile and libc-dependent. Snapshotting
166                    // first sidesteps the issue entirely, and the keys can
167                    // be of arbitrary length without truncation.
168                    // Edition 2024 requires extern blocks declaring foreign
169                    // statics to be wrapped in `unsafe extern`.
170                    unsafe extern "C" {
171                        static mut environ: *mut *mut libc::c_char;
172                    }
173                    let mut names: Vec<std::ffi::CString> = Vec::new();
174                    if !environ.is_null() {
175                        let mut p = environ;
176                        while !(*p).is_null() {
177                            let entry = *p;
178                            // Find the '=' separator (or NUL if malformed).
179                            let mut len = 0usize;
180                            while *entry.add(len) != 0 && *entry.add(len) != b'=' as libc::c_char {
181                                len += 1;
182                            }
183                            if len > 0 {
184                                let bytes = std::slice::from_raw_parts(entry.cast::<u8>(), len);
185                                if let Ok(c) = std::ffi::CString::new(bytes) {
186                                    names.push(c);
187                                }
188                            }
189                            p = p.add(1);
190                        }
191                    }
192                    for name in &names {
193                        libc::unsetenv(name.as_ptr());
194                    }
195                }
196            }
197        }
198        for (k, v) in env_pairs {
199            libc::setenv(k.as_ptr(), v.as_ptr(), 1);
200        }
201    }
202}
203
204/// Validate environment-variable overrides and convert them to pairs of
205/// `CString` that can be safely applied between fork and exec on Unix.
206///
207/// `setenv` allocates, so the canonical safety model after `fork` is to
208/// only use async-signal-safe functions. We do still call `setenv` in the
209/// child — this codebase forks before any tokio worker threads exist, so
210/// allocator state is single-threaded and the call is sound in practice.
211/// Pre-building these `CString`s here means we don't have to allocate in
212/// the child on the keys or values themselves.
213#[cfg(unix)]
214fn build_env_cstrings(
215    env: &std::collections::HashMap<String, String>,
216) -> Result<Vec<(std::ffi::CString, std::ffi::CString)>> {
217    use std::ffi::CString;
218
219    let mut pairs: Vec<(CString, CString)> = Vec::with_capacity(env.len());
220    for (k, v) in env {
221        if k.contains('=') {
222            return Err(ExpectError::Spawn(SpawnError::InvalidArgument {
223                kind: "env key".to_string(),
224                value: k.clone(),
225                reason: "env key contains '='".to_string(),
226            }));
227        }
228        let key = CString::new(k.as_str()).map_err(|_| {
229            ExpectError::Spawn(SpawnError::InvalidArgument {
230                kind: "env key".to_string(),
231                value: k.clone(),
232                reason: "env key contains null byte".to_string(),
233            })
234        })?;
235        let val = CString::new(v.as_str()).map_err(|_| {
236            ExpectError::Spawn(SpawnError::InvalidArgument {
237                kind: "env value".to_string(),
238                value: v.clone(),
239                reason: "env value contains null byte".to_string(),
240            })
241        })?;
242        pairs.push((key, val));
243    }
244    Ok(pairs)
245}
246
247/// Spawner for PTY sessions.
248pub struct PtySpawner {
249    config: PtyConfig,
250}
251
252impl PtySpawner {
253    /// Create a new PTY spawner with default configuration.
254    #[must_use]
255    pub fn new() -> Self {
256        Self {
257            config: PtyConfig::default(),
258        }
259    }
260
261    /// Create a new PTY spawner with custom configuration.
262    #[must_use]
263    pub const fn with_config(config: PtyConfig) -> Self {
264        Self { config }
265    }
266
267    /// Set the terminal dimensions.
268    pub const fn set_dimensions(&mut self, cols: u16, rows: u16) {
269        self.config.dimensions = (cols, rows);
270    }
271
272    /// Spawn a command.
273    ///
274    /// # Runtime requirement (Unix)
275    ///
276    /// The Unix implementation forks and then calls `setenv` / `unsetenv` /
277    /// `clearenv` between fork and exec to apply the configured env mode.
278    /// Those libc functions are **not** async-signal-safe — they allocate
279    /// — so the post-fork window in the child must run on a single thread
280    /// for the call to be sound. In this crate that is true because
281    /// callers reach `spawn` directly from a fresh `tokio::main` or
282    /// equivalent before any background thread has captured the
283    /// allocator lock at the fork point.
284    ///
285    /// **If you embed this crate in a host that pre-spawns worker
286    /// threads (for example, a multi-threaded scheduler that's already
287    /// running by the time you call `Session::spawn`)**, the assumption
288    /// breaks: another thread may hold the allocator lock at the moment
289    /// of `fork`, and the child can deadlock or corrupt heap state on
290    /// the first `setenv` call. In that environment, prefer a
291    /// `posix_spawn`-based spawner or a pre-fork sentinel-pipe helper.
292    ///
293    /// # Errors
294    ///
295    /// Returns an error if:
296    /// - The command or arguments contain null bytes
297    /// - PTY allocation fails
298    /// - Fork fails
299    /// - Exec fails (child exits with code 1)
300    #[cfg(unix)]
301    #[allow(unsafe_code)]
302    #[allow(clippy::unused_async)]
303    pub async fn spawn(&self, command: &str, args: &[String]) -> Result<PtyHandle> {
304        use std::ffi::CString;
305
306        // Validate and create CStrings BEFORE forking so we can return proper errors
307        let cmd_cstring = CString::new(command).map_err(|_| {
308            ExpectError::Spawn(SpawnError::InvalidArgument {
309                kind: "command".to_string(),
310                value: command.to_string(),
311                reason: "command contains null byte".to_string(),
312            })
313        })?;
314
315        let mut argv_cstrings: Vec<CString> = Vec::with_capacity(args.len() + 1);
316        argv_cstrings.push(cmd_cstring.clone());
317
318        for (idx, arg) in args.iter().enumerate() {
319            let arg_cstring = CString::new(arg.as_str()).map_err(|_| {
320                ExpectError::Spawn(SpawnError::InvalidArgument {
321                    kind: format!("argument[{idx}]"),
322                    value: arg.clone(),
323                    reason: "argument contains null byte".to_string(),
324                })
325            })?;
326            argv_cstrings.push(arg_cstring);
327        }
328
329        // Validate env entries before fork so we can return a clean error.
330        let env_pairs = build_env_cstrings(&self.config.env)?;
331        let env_mode = self.config.env_mode;
332
333        // Create PTY pair
334        // SAFETY: openpty() is called with valid pointers to stack-allocated integers.
335        // The null pointers for name, termp, and winp are explicitly allowed per POSIX.
336        // We check the return value and handle errors appropriately.
337        let pty_result = unsafe {
338            let mut master: libc::c_int = 0;
339            let mut slave: libc::c_int = 0;
340
341            // Open PTY
342            if libc::openpty(
343                &raw mut master,
344                &raw mut slave,
345                std::ptr::null_mut(),
346                std::ptr::null_mut(),
347                std::ptr::null_mut(),
348            ) != 0
349            {
350                return Err(ExpectError::Spawn(SpawnError::PtyAllocation {
351                    reason: "Failed to open PTY".to_string(),
352                }));
353            }
354
355            (master, slave)
356        };
357
358        let (master_fd, slave_fd) = pty_result;
359
360        // Fork the process
361        // SAFETY: fork() is safe to call at this point as we have no threads running
362        // that could hold locks. The child process will immediately set up its
363        // environment and exec into the target program.
364        let pid = unsafe { libc::fork() };
365
366        match pid {
367            -1 => Err(ExpectError::Spawn(SpawnError::Io(
368                io::Error::last_os_error(),
369            ))),
370            0 => {
371                // Child process
372                // SAFETY: This runs in the forked child process only. We:
373                // - Close the master fd (not needed in child)
374                // - Create a new session with setsid()
375                // - Set the slave as the controlling terminal via TIOCSCTTY
376                // - Redirect stdin/stdout/stderr to the slave pty
377                // - Close the original slave fd if it's not 0, 1, or 2
378                // - Execute the target command (never returns on success)
379                // - Exit with code 1 if exec fails
380                // All file descriptors are valid and owned by this process.
381                unsafe {
382                    libc::close(master_fd);
383                    libc::setsid();
384                    // Cast TIOCSCTTY to c_ulong for macOS compatibility (u32 -> u64)
385                    libc::ioctl(slave_fd, libc::TIOCSCTTY as libc::c_ulong, 0);
386
387                    libc::dup2(slave_fd, 0);
388                    libc::dup2(slave_fd, 1);
389                    libc::dup2(slave_fd, 2);
390
391                    if slave_fd > 2 {
392                        libc::close(slave_fd);
393                    }
394
395                    // Apply env_mode + overrides before exec.
396                    apply_env_in_child(env_mode, &env_pairs);
397
398                    // Use pre-validated CStrings (validated before fork)
399                    let argv_ptrs: Vec<*const libc::c_char> = argv_cstrings
400                        .iter()
401                        .map(|s| s.as_ptr())
402                        .chain(std::iter::once(std::ptr::null()))
403                        .collect();
404
405                    libc::execvp(cmd_cstring.as_ptr(), argv_ptrs.as_ptr());
406                    libc::_exit(1);
407                }
408            }
409            child_pid => {
410                // Parent process
411                // SAFETY: slave_fd is a valid file descriptor obtained from openpty().
412                // The parent doesn't need the slave end; only the child uses it.
413                unsafe {
414                    libc::close(slave_fd);
415                }
416
417                // Set non-blocking
418                // SAFETY: master_fd is a valid file descriptor from openpty().
419                // F_GETFL and F_SETFL with O_NONBLOCK are standard operations
420                // that don't violate any safety invariants.
421                unsafe {
422                    let flags = libc::fcntl(master_fd, libc::F_GETFL);
423                    libc::fcntl(master_fd, libc::F_SETFL, flags | libc::O_NONBLOCK);
424                }
425
426                Ok(PtyHandle {
427                    master_fd,
428                    pid: child_pid as u32,
429                    dimensions: self.config.dimensions,
430                })
431            }
432        }
433    }
434
435    /// Spawn a command on Windows using ConPTY.
436    ///
437    /// # Errors
438    ///
439    /// Returns an error if:
440    /// - ConPTY is not available (Windows version too old)
441    /// - PTY allocation fails
442    /// - Process spawning fails
443    #[cfg(windows)]
444    pub async fn spawn(&self, command: &str, args: &[String]) -> Result<WindowsPtyHandle> {
445        use rust_pty::{PtySystem, WindowsPtySystem};
446
447        // Build env per env_mode:
448        // - Inherit: env: None (rust-pty inherits parent env), but if we also
449        //   have overrides, we need to inherit + overlay → build a full map.
450        // - Clear:   env: Some(our overrides) — parent env discarded.
451        // - Extend:  env: Some(parent + our overrides), parent first so ours win.
452        let built_env: Option<std::collections::HashMap<std::ffi::OsString, std::ffi::OsString>> =
453            match self.config.env_mode {
454                EnvMode::Inherit if self.config.env.is_empty() => None,
455                EnvMode::Inherit | EnvMode::Extend => {
456                    let mut m: std::collections::HashMap<_, _> = std::env::vars_os().collect();
457                    for (k, v) in &self.config.env {
458                        m.insert(std::ffi::OsString::from(k), std::ffi::OsString::from(v));
459                    }
460                    Some(m)
461                }
462                EnvMode::Clear => Some(
463                    self.config
464                        .env
465                        .iter()
466                        .map(|(k, v)| (std::ffi::OsString::from(k), std::ffi::OsString::from(v)))
467                        .collect(),
468                ),
469            };
470
471        // Create configuration for rust-pty
472        let pty_config = rust_pty::PtyConfig {
473            window_size: self.config.dimensions,
474            env: match self.config.env_mode {
475                EnvMode::Clear if self.config.env.is_empty() => {
476                    Some(std::collections::HashMap::new())
477                }
478                _ => built_env,
479            },
480            ..Default::default()
481        };
482
483        // Spawn using rust-pty's Windows implementation
484        let (master, child) =
485            WindowsPtySystem::spawn(command, args.iter().map(|s| s.as_str()), &pty_config)
486                .await
487                .map_err(|e| {
488                    ExpectError::Spawn(SpawnError::PtyAllocation {
489                        reason: format!("Windows ConPTY spawn failed: {e}"),
490                    })
491                })?;
492
493        Ok(WindowsPtyHandle {
494            master,
495            child,
496            dimensions: self.config.dimensions,
497        })
498    }
499}
500
501impl Default for PtySpawner {
502    fn default() -> Self {
503        Self::new()
504    }
505}
506
/// Unix handle to a process spawned on a PTY.
///
/// Owns the master side of the PTY; the descriptor is closed on drop.
#[cfg(unix)]
#[derive(Debug)]
pub struct PtyHandle {
    /// Raw master PTY descriptor returned by `openpty`.
    master_fd: i32,
    /// PID returned by `fork` in the parent.
    pid: u32,
    /// Last known window size as (cols, rows).
    dimensions: (u16, u16),
}
518
/// Windows handle to a process spawned on a ConPTY.
#[cfg(windows)]
pub struct WindowsPtyHandle {
    /// Master side of the pseudo console, as provided by rust-pty.
    pub(crate) master: rust_pty::WindowsPtyMaster,
    /// Handle to the spawned child process, as provided by rust-pty.
    pub(crate) child: rust_pty::WindowsPtyChild,
    /// Last known window size as (cols, rows).
    dimensions: (u16, u16),
}
529
#[cfg(windows)]
impl std::fmt::Debug for WindowsPtyHandle {
    /// Shows only the dimensions; the rust-pty handles are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("WindowsPtyHandle");
        builder.field("dimensions", &self.dimensions);
        builder.finish_non_exhaustive()
    }
}
538
539#[cfg(unix)]
540impl PtyHandle {
541    /// Get the process ID.
542    #[must_use]
543    pub const fn pid(&self) -> u32 {
544        self.pid
545    }
546
547    /// Get the terminal dimensions.
548    #[must_use]
549    pub const fn dimensions(&self) -> (u16, u16) {
550        self.dimensions
551    }
552
553    /// Resize the terminal.
554    #[allow(unsafe_code)]
555    pub fn resize(&mut self, cols: u16, rows: u16) -> Result<()> {
556        let winsize = libc::winsize {
557            ws_row: rows,
558            ws_col: cols,
559            ws_xpixel: 0,
560            ws_ypixel: 0,
561        };
562
563        // SAFETY: master_fd is a valid PTY file descriptor stored in self.
564        // TIOCSWINSZ is a valid ioctl command for PTYs that sets the window size.
565        // winsize is a valid pointer to a properly initialized struct on the stack.
566        // Cast to c_ulong for macOS compatibility (u32 -> u64).
567        let result =
568            unsafe { libc::ioctl(self.master_fd, libc::TIOCSWINSZ as libc::c_ulong, &winsize) };
569
570        if result != 0 {
571            Err(ExpectError::Io(io::Error::last_os_error()))
572        } else {
573            self.dimensions = (cols, rows);
574            Ok(())
575        }
576    }
577
578    /// Wait for the process to exit.
579    #[allow(unsafe_code)]
580    pub fn wait(&self) -> Result<i32> {
581        let mut status: libc::c_int = 0;
582        // SAFETY: self.pid is a valid process ID from fork().
583        // status is a valid pointer to a stack-allocated integer.
584        // The options argument (0) means blocking wait, which is valid.
585        let result = unsafe { libc::waitpid(self.pid as i32, &raw mut status, 0) };
586
587        if result == -1 {
588            Err(ExpectError::Io(io::Error::last_os_error()))
589        } else if libc::WIFEXITED(status) {
590            Ok(libc::WEXITSTATUS(status))
591        } else if libc::WIFSIGNALED(status) {
592            Ok(128 + libc::WTERMSIG(status))
593        } else {
594            Ok(-1)
595        }
596    }
597
598    /// Send a signal to the process.
599    #[allow(unsafe_code)]
600    pub fn signal(&self, signal: i32) -> Result<()> {
601        // SAFETY: self.pid is a valid process ID from fork().
602        // The signal is passed from the caller and must be a valid signal number.
603        // kill() is safe to call with any PID; it returns an error for invalid PIDs.
604        let result = unsafe { libc::kill(self.pid as i32, signal) };
605        if result != 0 {
606            Err(ExpectError::Io(io::Error::last_os_error()))
607        } else {
608            Ok(())
609        }
610    }
611
612    /// Kill the process.
613    pub fn kill(&self) -> Result<()> {
614        self.signal(libc::SIGKILL)
615    }
616}
617
#[cfg(windows)]
impl WindowsPtyHandle {
    /// Process ID of the child, as reported by rust-pty.
    #[must_use]
    pub fn pid(&self) -> u32 {
        self.child.pid()
    }

    /// Cached terminal dimensions as (cols, rows).
    #[must_use]
    pub const fn dimensions(&self) -> (u16, u16) {
        self.dimensions
    }

    /// Resize the pseudo console; the cached dimensions change only on success.
    pub fn resize(&mut self, cols: u16, rows: u16) -> Result<()> {
        use rust_pty::{PtyMaster, WindowSize};

        if let Err(e) = self.master.resize(WindowSize::new(cols, rows)) {
            return Err(ExpectError::Io(io::Error::other(format!("resize failed: {e}"))));
        }
        self.dimensions = (cols, rows);
        Ok(())
    }

    /// Whether the child process is still running, per rust-pty.
    #[must_use]
    pub fn is_running(&self) -> bool {
        self.child.is_running()
    }

    /// Terminate the child process through rust-pty.
    pub fn kill(&mut self) -> Result<()> {
        match self.child.kill() {
            Ok(value) => Ok(value),
            Err(e) => Err(ExpectError::Io(io::Error::other(format!("kill failed: {e}")))),
        }
    }
}
656
657#[cfg(unix)]
658impl Drop for PtyHandle {
659    #[allow(unsafe_code)]
660    fn drop(&mut self) {
661        // Close the master fd
662        // SAFETY: master_fd is a valid file descriptor obtained from openpty()
663        // and stored in this struct. It has not been closed elsewhere as we own it.
664        // Closing in Drop ensures the fd is released when the handle is dropped.
665        unsafe {
666            libc::close(self.master_fd);
667        }
668    }
669}
670
671/// Async wrapper around a PTY file descriptor for use with Tokio.
672///
673/// This provides `AsyncRead` and `AsyncWrite` implementations that
674/// integrate with the Tokio runtime.
675#[cfg(unix)]
676pub struct AsyncPty {
677    /// The async file descriptor wrapper.
678    inner: tokio::io::unix::AsyncFd<std::os::unix::io::RawFd>,
679    /// Process ID.
680    pid: u32,
681    /// Terminal dimensions.
682    dimensions: (u16, u16),
683}
684
685#[cfg(unix)]
686impl AsyncPty {
687    /// Create a new async PTY wrapper from a `PtyHandle`.
688    ///
689    /// Takes ownership of the `PtyHandle`'s file descriptor.
690    ///
691    /// # Errors
692    ///
693    /// Returns an error if the `AsyncFd` cannot be created.
694    pub fn from_handle(handle: PtyHandle) -> io::Result<Self> {
695        let fd = handle.master_fd;
696        let pid = handle.pid;
697        let dimensions = handle.dimensions;
698
699        // Prevent the original handle from closing the fd
700        std::mem::forget(handle);
701
702        let inner = tokio::io::unix::AsyncFd::new(fd)?;
703        Ok(Self {
704            inner,
705            pid,
706            dimensions,
707        })
708    }
709
710    /// Get the process ID.
711    #[must_use]
712    pub const fn pid(&self) -> u32 {
713        self.pid
714    }
715
716    /// Get the terminal dimensions.
717    #[must_use]
718    pub const fn dimensions(&self) -> (u16, u16) {
719        self.dimensions
720    }
721
722    /// Resize the terminal.
723    #[allow(unsafe_code)]
724    pub fn resize(&mut self, cols: u16, rows: u16) -> Result<()> {
725        let winsize = libc::winsize {
726            ws_row: rows,
727            ws_col: cols,
728            ws_xpixel: 0,
729            ws_ypixel: 0,
730        };
731
732        // SAFETY: The fd is valid and TIOCSWINSZ is a valid ioctl for PTYs.
733        // Cast to c_ulong for macOS compatibility (u32 -> u64).
734        let result = unsafe {
735            libc::ioctl(
736                *self.inner.get_ref(),
737                libc::TIOCSWINSZ as libc::c_ulong,
738                &winsize,
739            )
740        };
741
742        if result != 0 {
743            Err(ExpectError::Io(io::Error::last_os_error()))
744        } else {
745            self.dimensions = (cols, rows);
746            Ok(())
747        }
748    }
749
750    /// Send a signal to the child process.
751    #[allow(unsafe_code)]
752    pub fn signal(&self, signal: i32) -> Result<()> {
753        // SAFETY: pid is a valid process ID from fork().
754        let result = unsafe { libc::kill(self.pid as i32, signal) };
755        if result != 0 {
756            Err(ExpectError::Io(io::Error::last_os_error()))
757        } else {
758            Ok(())
759        }
760    }
761
762    /// Kill the child process.
763    pub fn kill(&self) -> Result<()> {
764        self.signal(libc::SIGKILL)
765    }
766}
767
768#[cfg(unix)]
769impl AsyncRead for AsyncPty {
770    #[allow(unsafe_code)]
771    fn poll_read(
772        self: Pin<&mut Self>,
773        cx: &mut Context<'_>,
774        buf: &mut ReadBuf<'_>,
775    ) -> Poll<io::Result<()>> {
776        loop {
777            let mut guard = match self.inner.poll_read_ready(cx) {
778                Poll::Ready(Ok(guard)) => guard,
779                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
780                Poll::Pending => return Poll::Pending,
781            };
782
783            let fd = *self.inner.get_ref();
784            let unfilled = buf.initialize_unfilled();
785
786            // SAFETY: fd is a valid file descriptor, unfilled is a valid buffer.
787            let result = unsafe {
788                libc::read(
789                    fd,
790                    unfilled.as_mut_ptr().cast::<libc::c_void>(),
791                    unfilled.len(),
792                )
793            };
794
795            if result >= 0 {
796                buf.advance(result as usize);
797                return Poll::Ready(Ok(()));
798            }
799
800            let err = io::Error::last_os_error();
801            if err.kind() == io::ErrorKind::WouldBlock {
802                guard.clear_ready();
803                continue;
804            }
805            return Poll::Ready(Err(err));
806        }
807    }
808}
809
810#[cfg(unix)]
811impl AsyncWrite for AsyncPty {
812    #[allow(unsafe_code)]
813    fn poll_write(
814        self: Pin<&mut Self>,
815        cx: &mut Context<'_>,
816        buf: &[u8],
817    ) -> Poll<io::Result<usize>> {
818        loop {
819            let mut guard = match self.inner.poll_write_ready(cx) {
820                Poll::Ready(Ok(guard)) => guard,
821                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
822                Poll::Pending => return Poll::Pending,
823            };
824
825            let fd = *self.inner.get_ref();
826
827            // SAFETY: fd is a valid file descriptor, buf is a valid buffer.
828            let result = unsafe { libc::write(fd, buf.as_ptr().cast::<libc::c_void>(), buf.len()) };
829
830            if result >= 0 {
831                return Poll::Ready(Ok(result as usize));
832            }
833
834            let err = io::Error::last_os_error();
835            if err.kind() == io::ErrorKind::WouldBlock {
836                guard.clear_ready();
837                continue;
838            }
839            return Poll::Ready(Err(err));
840        }
841    }
842
843    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
844        // PTY doesn't need explicit flushing
845        Poll::Ready(Ok(()))
846    }
847
848    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
849        // Shutdown is handled by Drop
850        Poll::Ready(Ok(()))
851    }
852}
853
854#[cfg(unix)]
855impl Drop for AsyncPty {
856    #[allow(unsafe_code)]
857    fn drop(&mut self) {
858        // SAFETY: The fd is valid and owned by us.
859        unsafe {
860            libc::close(*self.inner.get_ref());
861        }
862    }
863}
864
865#[cfg(unix)]
866impl std::fmt::Debug for AsyncPty {
867    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
868        f.debug_struct("AsyncPty")
869            .field("fd", self.inner.get_ref())
870            .field("pid", &self.pid)
871            .field("dimensions", &self.dimensions)
872            .finish()
873    }
874}
875
/// Tokio-facing wrapper around a Windows ConPTY session.
///
/// Holds the rust-pty master and child handles together with cached
/// metadata. It mirrors the Unix `AsyncPty` surface so `Session` can treat
/// both platforms uniformly.
#[cfg(windows)]
pub struct WindowsAsyncPty {
    /// Master side of the pseudo console.
    master: rust_pty::WindowsPtyMaster,
    /// Handle to the spawned child process.
    child: rust_pty::WindowsPtyChild,
    /// Child process ID captured at construction.
    pid: u32,
    /// Last known window size as (cols, rows).
    dimensions: (u16, u16),
}
891
#[cfg(windows)]
impl WindowsAsyncPty {
    /// Create a new Windows async PTY wrapper from a `WindowsPtyHandle`.
    ///
    /// Takes ownership of the handle. The child's process ID and the handle's
    /// terminal dimensions are captured once, at construction time.
    #[must_use]
    pub fn from_handle(handle: WindowsPtyHandle) -> Self {
        let pid = handle.child.pid();
        let dimensions = handle.dimensions;
        Self {
            master: handle.master,
            child: handle.child,
            pid,
            dimensions,
        }
    }

    /// Get the process ID of the child process.
    #[must_use]
    pub const fn pid(&self) -> u32 {
        self.pid
    }

    /// Get the terminal dimensions as `(cols, rows)`.
    #[must_use]
    pub const fn dimensions(&self) -> (u16, u16) {
        self.dimensions
    }

    /// Resize the terminal to `cols` columns by `rows` rows.
    ///
    /// The cached dimensions are updated only after the underlying resize
    /// succeeds, so they never describe a size the ConPTY rejected.
    ///
    /// # Errors
    ///
    /// Returns [`ExpectError::Io`] if the underlying PTY master fails to resize.
    pub fn resize(&mut self, cols: u16, rows: u16) -> Result<()> {
        use rust_pty::{PtyMaster, WindowSize};
        let size = WindowSize::new(cols, rows);
        self.master
            .resize(size)
            .map_err(|e| ExpectError::Io(io::Error::other(format!("resize failed: {e}"))))?;
        self.dimensions = (cols, rows);
        Ok(())
    }

    /// Check if the child process is still running.
    #[must_use]
    pub fn is_running(&self) -> bool {
        self.child.is_running()
    }

    /// Kill the child process.
    ///
    /// # Errors
    ///
    /// Returns [`ExpectError::Io`] if the child process could not be killed.
    pub fn kill(&mut self) -> Result<()> {
        self.child
            .kill()
            .map_err(|e| ExpectError::Io(io::Error::other(format!("kill failed: {e}"))))
    }
}
944
#[cfg(windows)]
impl AsyncRead for WindowsAsyncPty {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Reads come straight from the ConPTY master, which implements
        // `AsyncRead` itself; this wrapper adds no buffering of its own.
        let this = self.get_mut();
        Pin::new(&mut this.master).poll_read(cx, buf)
    }
}
956
#[cfg(windows)]
impl AsyncWrite for WindowsAsyncPty {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // Every write operation is forwarded unchanged to the ConPTY master.
        let this = self.get_mut();
        Pin::new(&mut this.master).poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.master).poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.master).poll_shutdown(cx)
    }
}
975
#[cfg(windows)]
impl std::fmt::Debug for WindowsAsyncPty {
    /// Shows the child PID and terminal dimensions; the `master` and `child`
    /// handles are omitted, which `finish_non_exhaustive` marks with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            pid, dimensions, ..
        } = self;
        f.debug_struct("WindowsAsyncPty")
            .field("pid", pid)
            .field("dimensions", dimensions)
            .finish_non_exhaustive()
    }
}
985
#[cfg(test)]
mod tests {
    use super::*;

    /// A default `PtyConfig` is an 80x24 terminal that inherits the environment.
    #[test]
    fn pty_config_default() {
        let config = PtyConfig::default();
        let cols = config.dimensions.0;
        let rows = config.dimensions.1;
        assert_eq!(cols, 80);
        assert_eq!(rows, 24);
        assert_eq!(config.env_mode, EnvMode::Inherit);
    }

    /// Converting from a `SessionConfig` carries its dimensions across.
    #[test]
    fn pty_config_from_session() {
        let session_config = SessionConfig {
            dimensions: (120, 40),
            ..SessionConfig::default()
        };

        let pty_config: PtyConfig = (&session_config).into();
        let cols = pty_config.dimensions.0;
        let rows = pty_config.dimensions.1;
        assert_eq!(cols, 120);
        assert_eq!(rows, 40);
    }

    /// An interior NUL in the program name must be rejected before spawning.
    #[cfg(unix)]
    #[tokio::test]
    async fn spawn_rejects_null_byte_in_command() {
        let spawner = PtySpawner::new();
        let outcome = spawner.spawn("test\0command", &[]).await;

        let Err(err) = outcome else {
            panic!("spawn unexpectedly succeeded with a NUL in the command");
        };
        let err_str = err.to_string();
        assert!(
            err_str.contains("null byte"),
            "Expected error about null byte, got: {err_str}"
        );
    }

    /// An interior NUL in any argument must be rejected before spawning.
    #[cfg(unix)]
    #[tokio::test]
    async fn spawn_rejects_null_byte_in_args() {
        let spawner = PtySpawner::new();
        let args = ["hello\0world".to_string()];
        let outcome = spawner.spawn("/bin/echo", &args).await;

        let Err(err) = outcome else {
            panic!("spawn unexpectedly succeeded with a NUL in an argument");
        };
        let err_str = err.to_string();
        assert!(
            err_str.contains("null byte"),
            "Expected error about null byte, got: {err_str}"
        );
    }
}