1use libc::{self, winsize};
2use std::io;
3use std::os::unix::io::RawFd;
4use std::process::{Child, Command};
5use std::sync::{Arc, Mutex};
6use std::thread::{self, JoinHandle};
7
8pub use virtual_tty::VirtualTty;
10
11pub struct PtyAdapter {
12 virtual_tty: Arc<Mutex<VirtualTty>>,
13 master_fd: Option<RawFd>,
14 slave_fd: Option<RawFd>,
15 reader_thread: Option<JoinHandle<()>>,
16}
17
18impl PtyAdapter {
19 pub fn new(width: usize, height: usize) -> Self {
20 Self {
21 virtual_tty: Arc::new(Mutex::new(VirtualTty::new(width, height))),
22 master_fd: None,
23 slave_fd: None,
24 reader_thread: None,
25 }
26 }
27
28 pub fn from_virtual_tty(virtual_tty: VirtualTty) -> Self {
29 Self {
30 virtual_tty: Arc::new(Mutex::new(virtual_tty)),
31 master_fd: None,
32 slave_fd: None,
33 reader_thread: None,
34 }
35 }
36
37 pub fn get_virtual_tty(&self) -> Arc<Mutex<VirtualTty>> {
38 self.virtual_tty.clone()
39 }
40
41 pub fn get_snapshot(&self) -> String {
42 self.virtual_tty.lock().unwrap().get_snapshot()
43 }
44
45 pub fn get_size(&self) -> (usize, usize) {
46 self.virtual_tty.lock().unwrap().get_size()
47 }
48
49 fn create_pty(&mut self) -> io::Result<()> {
50 if self.master_fd.is_some() {
51 return Ok(());
52 }
53
54 let (width, height) = self.get_size();
55
56 unsafe {
57 let mut master: libc::c_int = 0;
58 let mut slave: libc::c_int = 0;
59
60 let result = libc::openpty(
62 &mut master,
63 &mut slave,
64 std::ptr::null_mut(),
65 std::ptr::null_mut(),
66 std::ptr::null_mut(),
67 );
68
69 if result != 0 {
70 return Err(io::Error::last_os_error());
71 }
72
73 let ws = winsize {
75 ws_row: height as u16,
76 ws_col: width as u16,
77 ws_xpixel: 0,
78 ws_ypixel: 0,
79 };
80
81 let _ = libc::ioctl(master, libc::TIOCSWINSZ, &ws);
82
83 self.master_fd = Some(master);
84 self.slave_fd = Some(slave);
85 }
86
87 Ok(())
88 }
89
90 fn start_reader_thread(&mut self) {
91 if self.reader_thread.is_some() {
92 return;
93 }
94
95 let master_fd = match self.master_fd {
96 Some(fd) => fd,
97 None => return,
98 };
99
100 let virtual_tty = self.virtual_tty.clone();
101
102 let reader_thread = thread::spawn(move || {
103 let mut read_buffer = [0u8; 4096];
104
105 loop {
106 let n = unsafe {
107 libc::read(
108 master_fd,
109 read_buffer.as_mut_ptr() as *mut libc::c_void,
110 read_buffer.len(),
111 )
112 };
113 match n {
114 0 => break, n if n > 0 => {
116 let data = String::from_utf8_lossy(&read_buffer[..n as usize]);
117 virtual_tty.lock().unwrap().stdout_write(&data);
118 }
119 _ => break,
120 }
121 }
122 });
123
124 self.reader_thread = Some(reader_thread);
125 }
126
127 pub fn spawn_command(&mut self, cmd: &mut Command) -> io::Result<Child> {
128 self.create_pty()?;
129 self.start_reader_thread();
130
131 let slave_fd = self
132 .slave_fd
133 .ok_or_else(|| io::Error::other("No slave PTY"))?;
134
135 let slave_stdin = unsafe { libc::dup(slave_fd) };
137 let slave_stdout = unsafe { libc::dup(slave_fd) };
138 let slave_stderr = unsafe { libc::dup(slave_fd) };
139
140 if slave_stdin < 0 || slave_stdout < 0 || slave_stderr < 0 {
141 return Err(io::Error::last_os_error());
142 }
143
144 unsafe {
145 use std::os::unix::io::FromRawFd;
146 use std::process::Stdio;
147
148 cmd.stdin(Stdio::from_raw_fd(slave_stdin))
149 .stdout(Stdio::from_raw_fd(slave_stdout))
150 .stderr(Stdio::from_raw_fd(slave_stderr));
151 }
152
153 cmd.spawn()
154 }
155
156 pub fn send_input(&mut self, input: &[u8]) -> io::Result<()> {
157 let master_fd = self
158 .master_fd
159 .ok_or_else(|| io::Error::other("No master PTY"))?;
160
161 let result = unsafe {
162 libc::write(
163 master_fd,
164 input.as_ptr() as *const libc::c_void,
165 input.len(),
166 )
167 };
168 if result < 0 {
169 Err(io::Error::last_os_error())
170 } else {
171 Ok(())
172 }
173 }
174
175 pub fn send_input_str(&mut self, input: &str) -> io::Result<()> {
177 self.send_input(input.as_bytes())
178 }
179
180 pub fn wait_for_completion(&mut self) {
182 if let Some(thread) = self.reader_thread.take() {
183 let _ = thread.join();
184 }
185 }
186}
187
188impl Drop for PtyAdapter {
189 fn drop(&mut self) {
190 if let Some(fd) = self.master_fd {
192 unsafe {
193 libc::close(fd);
194 }
195 }
196 if let Some(fd) = self.slave_fd {
197 unsafe {
198 libc::close(fd);
199 }
200 }
201
202 if let Some(thread) = self.reader_thread.take() {
204 let _ = thread.join();
205 }
206 }
207}