virtual_tty_pty/
lib.rs

1use libc::{self, winsize};
2use std::io;
3use std::os::unix::io::RawFd;
4use std::process::{Child, Command};
5use std::sync::{Arc, Mutex};
6use std::thread::{self, JoinHandle};
7
8// Re-export the core VirtualTty
9pub use virtual_tty::VirtualTty;
10
/// Couples a [`VirtualTty`] with a pseudo-terminal (PTY) so that the output of
/// spawned commands is rendered into the virtual terminal.
///
/// The PTY is opened lazily the first time a command is spawned; until then
/// both descriptor fields and the reader thread handle are `None`.
pub struct PtyAdapter {
    /// Terminal state shared between the adapter and its reader thread.
    virtual_tty: Arc<Mutex<VirtualTty>>,
    /// Master side of the PTY; `None` while no PTY is open.
    master_fd: Option<RawFd>,
    /// Slave side of the PTY, duplicated to become a child's stdin/stdout/stderr;
    /// `None` while no PTY is open.
    slave_fd: Option<RawFd>,
    /// Background thread that copies bytes read from the master into the
    /// virtual terminal; `None` while no reader is running.
    reader_thread: Option<JoinHandle<()>>,
}
17
18impl PtyAdapter {
19    pub fn new(width: usize, height: usize) -> Self {
20        Self {
21            virtual_tty: Arc::new(Mutex::new(VirtualTty::new(width, height))),
22            master_fd: None,
23            slave_fd: None,
24            reader_thread: None,
25        }
26    }
27
28    pub fn from_virtual_tty(virtual_tty: VirtualTty) -> Self {
29        Self {
30            virtual_tty: Arc::new(Mutex::new(virtual_tty)),
31            master_fd: None,
32            slave_fd: None,
33            reader_thread: None,
34        }
35    }
36
37    pub fn get_virtual_tty(&self) -> Arc<Mutex<VirtualTty>> {
38        self.virtual_tty.clone()
39    }
40
41    pub fn get_snapshot(&self) -> String {
42        self.virtual_tty.lock().unwrap().get_snapshot()
43    }
44
45    pub fn get_size(&self) -> (usize, usize) {
46        self.virtual_tty.lock().unwrap().get_size()
47    }
48
49    fn create_pty(&mut self) -> io::Result<()> {
50        if self.master_fd.is_some() {
51            return Ok(());
52        }
53
54        let (width, height) = self.get_size();
55
56        unsafe {
57            let mut master: libc::c_int = 0;
58            let mut slave: libc::c_int = 0;
59
60            // Open a new PTY
61            let result = libc::openpty(
62                &mut master,
63                &mut slave,
64                std::ptr::null_mut(),
65                std::ptr::null_mut(),
66                std::ptr::null_mut(),
67            );
68
69            if result != 0 {
70                return Err(io::Error::last_os_error());
71            }
72
73            // Set window size
74            let ws = winsize {
75                ws_row: height as u16,
76                ws_col: width as u16,
77                ws_xpixel: 0,
78                ws_ypixel: 0,
79            };
80
81            let _ = libc::ioctl(master, libc::TIOCSWINSZ, &ws);
82
83            self.master_fd = Some(master);
84            self.slave_fd = Some(slave);
85        }
86
87        Ok(())
88    }
89
90    fn start_reader_thread(&mut self) {
91        if self.reader_thread.is_some() {
92            return;
93        }
94
95        let master_fd = match self.master_fd {
96            Some(fd) => fd,
97            None => return,
98        };
99
100        let virtual_tty = self.virtual_tty.clone();
101
102        let reader_thread = thread::spawn(move || {
103            let mut read_buffer = [0u8; 4096];
104
105            loop {
106                let n = unsafe {
107                    libc::read(
108                        master_fd,
109                        read_buffer.as_mut_ptr() as *mut libc::c_void,
110                        read_buffer.len(),
111                    )
112                };
113                match n {
114                    0 => break, // EOF
115                    n if n > 0 => {
116                        let data = String::from_utf8_lossy(&read_buffer[..n as usize]);
117                        virtual_tty.lock().unwrap().stdout_write(&data);
118                    }
119                    _ => break,
120                }
121            }
122        });
123
124        self.reader_thread = Some(reader_thread);
125    }
126
127    pub fn spawn_command(&mut self, cmd: &mut Command) -> io::Result<Child> {
128        self.create_pty()?;
129        self.start_reader_thread();
130
131        let slave_fd = self
132            .slave_fd
133            .ok_or_else(|| io::Error::other("No slave PTY"))?;
134
135        // Duplicate the slave FD for stdin/stdout/stderr to avoid closing issues
136        let slave_stdin = unsafe { libc::dup(slave_fd) };
137        let slave_stdout = unsafe { libc::dup(slave_fd) };
138        let slave_stderr = unsafe { libc::dup(slave_fd) };
139
140        if slave_stdin < 0 || slave_stdout < 0 || slave_stderr < 0 {
141            return Err(io::Error::last_os_error());
142        }
143
144        unsafe {
145            use std::os::unix::io::FromRawFd;
146            use std::process::Stdio;
147
148            cmd.stdin(Stdio::from_raw_fd(slave_stdin))
149                .stdout(Stdio::from_raw_fd(slave_stdout))
150                .stderr(Stdio::from_raw_fd(slave_stderr));
151        }
152
153        cmd.spawn()
154    }
155
156    pub fn send_input(&mut self, input: &[u8]) -> io::Result<()> {
157        let master_fd = self
158            .master_fd
159            .ok_or_else(|| io::Error::other("No master PTY"))?;
160
161        let result = unsafe {
162            libc::write(
163                master_fd,
164                input.as_ptr() as *const libc::c_void,
165                input.len(),
166            )
167        };
168        if result < 0 {
169            Err(io::Error::last_os_error())
170        } else {
171            Ok(())
172        }
173    }
174
175    /// Convenience method to send string input
176    pub fn send_input_str(&mut self, input: &str) -> io::Result<()> {
177        self.send_input(input.as_bytes())
178    }
179
180    /// Wait for any running processes to complete and reader thread to finish
181    pub fn wait_for_completion(&mut self) {
182        if let Some(thread) = self.reader_thread.take() {
183            let _ = thread.join();
184        }
185    }
186}
187
188impl Drop for PtyAdapter {
189    fn drop(&mut self) {
190        // Close PTY file descriptors
191        if let Some(fd) = self.master_fd {
192            unsafe {
193                libc::close(fd);
194            }
195        }
196        if let Some(fd) = self.slave_fd {
197            unsafe {
198                libc::close(fd);
199            }
200        }
201
202        // Wait for reader thread to finish
203        if let Some(thread) = self.reader_thread.take() {
204            let _ = thread.join();
205        }
206    }
207}