vtcode_core/tools/pty.rs

1use std::collections::HashMap;
2use std::fs;
3use std::io::{Read, Write};
4use std::path::{Component, Path, PathBuf};
5#[cfg(unix)]
6use std::sync::mpsc;
7use std::sync::{Arc, Mutex};
8use std::thread::{self, JoinHandle};
9use std::time::Duration;
10#[cfg(unix)]
11use std::time::Instant;
12
13use anyhow::{Context, Result, anyhow};
14use portable_pty::{Child, CommandBuilder, MasterPty, PtySize, native_pty_system};
15use tracing::{debug, warn};
16use tui_term::vt100::Parser;
17
18use crate::config::PtyConfig;
19use crate::tools::types::VTCodePtySession;
20
/// Spawns and tracks processes attached to pseudo-terminals (PTYs) for a workspace.
///
/// Clones share one session table through the `Arc` in `inner`, so a session
/// created through one clone is visible to, and can be closed by, the others.
#[derive(Clone)]
pub struct PtyManager {
    /// Directory that requested working directories are resolved against and confined to.
    workspace_root: PathBuf,
    /// PTY configuration; stored here and handed out read-only via `config()`.
    config: PtyConfig,
    /// Session state shared between all clones of this manager.
    inner: Arc<PtyState>,
}
27
/// Mutable state shared by every clone of a `PtyManager`.
#[derive(Default)]
struct PtyState {
    /// Open interactive sessions, keyed by the caller-supplied session id.
    sessions: Mutex<HashMap<String, PtySessionHandle>>,
}
32
/// Live resources owned by one interactive PTY session.
struct PtySessionHandle {
    /// Master side of the PTY; queried for the current terminal size.
    master: Box<dyn MasterPty + Send>,
    /// The spawned process; polled and, if still running, killed on close.
    child: Mutex<Box<dyn Child + Send>>,
    /// Input side of the PTY. Becomes `None` once close has taken and dropped it.
    writer: Mutex<Option<Box<dyn Write + Send>>>,
    /// vt100 screen model, fed with output by the reader thread.
    parser: Arc<Mutex<Parser>>,
    /// Background thread draining PTY output into `parser`; taken and joined on close.
    reader_thread: Mutex<Option<JoinHandle<()>>>,
    /// Metadata captured at creation; `snapshot_metadata` returns refreshed copies.
    metadata: VTCodePtySession,
}
41
42impl PtySessionHandle {
43    fn snapshot_metadata(&self) -> VTCodePtySession {
44        let mut metadata = self.metadata.clone();
45        if let Ok(size) = self.master.get_size() {
46            metadata.rows = size.rows;
47            metadata.cols = size.cols;
48        }
49        if let Ok(parser) = self.parser.lock() {
50            metadata.screen_contents = Some(parser.screen().contents());
51        }
52        metadata
53    }
54}
55
56pub struct PtyCommandRequest {
57    pub command: Vec<String>,
58    pub working_dir: PathBuf,
59    pub timeout: Duration,
60    pub size: PtySize,
61}
62
63pub struct PtyCommandResult {
64    pub exit_code: i32,
65    pub output: String,
66    pub duration: Duration,
67    pub size: PtySize,
68}
69
70impl PtyManager {
71    pub fn new(workspace_root: PathBuf, config: PtyConfig) -> Self {
72        let resolved_root = workspace_root
73            .canonicalize()
74            .unwrap_or(workspace_root.clone());
75
76        Self {
77            workspace_root: resolved_root,
78            config,
79            inner: Arc::new(PtyState::default()),
80        }
81    }
82
83    pub fn config(&self) -> &PtyConfig {
84        &self.config
85    }
86
87    pub fn describe_working_dir(&self, path: &Path) -> String {
88        self.format_working_dir(path)
89    }
90
91    #[cfg(unix)]
92    pub async fn run_command(&self, request: PtyCommandRequest) -> Result<PtyCommandResult> {
93        if request.command.is_empty() {
94            return Err(anyhow!("PTY command cannot be empty"));
95        }
96
97        let mut command = request.command.clone();
98        let program = command.remove(0);
99        let args = command;
100        let timeout = clamp_timeout(request.timeout);
101        let work_dir = request.working_dir.clone();
102        let size = request.size;
103        let start = Instant::now();
104
105        let result = tokio::task::spawn_blocking(move || -> Result<PtyCommandResult> {
106            let timeout_duration = Duration::from_millis(timeout);
107            let mut builder = CommandBuilder::new(program.clone());
108            for arg in &args {
109                builder.arg(arg);
110            }
111            builder.cwd(&work_dir);
112            builder.env("TERM", "xterm-256color");
113            builder.env("COLUMNS", size.cols.to_string());
114            builder.env("LINES", size.rows.to_string());
115
116            let pty_system = native_pty_system();
117            let pair = pty_system
118                .openpty(size)
119                .context("failed to allocate PTY pair")?;
120
121            let mut child = pair
122                .slave
123                .spawn_command(builder)
124                .with_context(|| format!("failed to spawn PTY command '{program}'"))?;
125            let mut killer = child.clone_killer();
126            drop(pair.slave);
127
128            let reader = pair
129                .master
130                .try_clone_reader()
131                .context("failed to clone PTY reader")?;
132
133            let (wait_tx, wait_rx) = mpsc::channel();
134            let wait_thread = thread::spawn(move || {
135                let status = child.wait();
136                let _ = wait_tx.send(());
137                status
138            });
139
140            let reader_thread = thread::spawn(move || -> Result<Vec<u8>> {
141                let mut reader = reader;
142                let mut buffer = [0u8; 4096];
143                let mut collected = Vec::new();
144
145                loop {
146                    match reader.read(&mut buffer) {
147                        Ok(0) => break,
148                        Ok(bytes_read) => {
149                            collected.extend_from_slice(&buffer[..bytes_read]);
150                        }
151                        Err(error) if error.kind() == std::io::ErrorKind::Interrupted => continue,
152                        Err(error) => {
153                            return Err(error).context("failed to read PTY command output");
154                        }
155                    }
156                }
157
158                Ok(collected)
159            });
160
161            let wait_result = match wait_rx.recv_timeout(timeout_duration) {
162                Ok(()) => wait_thread
163                    .join()
164                    .map_err(|panic| anyhow!("PTY command wait thread panicked: {:?}", panic))?,
165                Err(mpsc::RecvTimeoutError::Timeout) => {
166                    killer
167                        .kill()
168                        .context("failed to terminate PTY command after timeout")?;
169
170                    let join_result = wait_thread.join().map_err(|panic| {
171                        anyhow!("PTY command wait thread panicked: {:?}", panic)
172                    })?;
173                    if let Err(error) = join_result {
174                        return Err(error)
175                            .context("failed to wait for PTY command to exit after timeout");
176                    }
177
178                    reader_thread
179                        .join()
180                        .map_err(|panic| {
181                            anyhow!("PTY command reader thread panicked: {:?}", panic)
182                        })?
183                        .context("failed to read PTY command output")?;
184
185                    return Err(anyhow!(
186                        "PTY command timed out after {} milliseconds",
187                        timeout
188                    ));
189                }
190                Err(mpsc::RecvTimeoutError::Disconnected) => {
191                    let join_result = wait_thread.join().map_err(|panic| {
192                        anyhow!("PTY command wait thread panicked: {:?}", panic)
193                    })?;
194                    if let Err(error) = join_result {
195                        return Err(error).context(
196                            "failed to wait for PTY command after wait channel disconnected",
197                        );
198                    }
199
200                    reader_thread
201                        .join()
202                        .map_err(|panic| {
203                            anyhow!("PTY command reader thread panicked: {:?}", panic)
204                        })?
205                        .context("failed to read PTY command output")?;
206
207                    return Err(anyhow!(
208                        "PTY command wait channel disconnected unexpectedly"
209                    ));
210                }
211            };
212
213            let status = wait_result.context("failed to wait for PTY command to exit")?;
214
215            let output_bytes = reader_thread
216                .join()
217                .map_err(|panic| anyhow!("PTY command reader thread panicked: {:?}", panic))?
218                .context("failed to read PTY command output")?;
219            let output = String::from_utf8_lossy(&output_bytes).to_string();
220            let exit_code = exit_status_code(status);
221
222            Ok(PtyCommandResult {
223                exit_code,
224                output,
225                duration: start.elapsed(),
226                size,
227            })
228        })
229        .await
230        .context("failed to join PTY command task")??;
231
232        Ok(result)
233    }
234
235    #[cfg(not(unix))]
236    pub async fn run_command(&self, request: PtyCommandRequest) -> Result<PtyCommandResult> {
237        if request.command.is_empty() {
238            return Err(anyhow!("PTY command cannot be empty"));
239        }
240
241        Err(anyhow!(
242            "PTY command execution is not supported on this platform"
243        ))
244    }
245
246    pub fn resolve_working_dir(&self, requested: Option<&str>) -> Result<PathBuf> {
247        let requested = match requested {
248            Some(dir) if !dir.trim().is_empty() => dir,
249            _ => return Ok(self.workspace_root.clone()),
250        };
251
252        let candidate = self.workspace_root.join(requested);
253        let normalized = normalize_path(&candidate);
254        if !normalized.starts_with(&self.workspace_root) {
255            return Err(anyhow!(
256                "Working directory '{}' escapes the workspace root",
257                candidate.display()
258            ));
259        }
260        let metadata = fs::metadata(&normalized).with_context(|| {
261            format!(
262                "Working directory '{}' does not exist",
263                normalized.display()
264            )
265        })?;
266        if !metadata.is_dir() {
267            return Err(anyhow!(
268                "Working directory '{}' is not a directory",
269                normalized.display()
270            ));
271        }
272        Ok(normalized)
273    }
274
275    pub fn create_session(
276        &self,
277        session_id: String,
278        command: Vec<String>,
279        working_dir: PathBuf,
280        size: PtySize,
281    ) -> Result<VTCodePtySession> {
282        if command.is_empty() {
283            return Err(anyhow!("PTY session command cannot be empty"));
284        }
285
286        let mut sessions = self
287            .inner
288            .sessions
289            .lock()
290            .expect("PTY session mutex poisoned");
291        if sessions.contains_key(&session_id) {
292            return Err(anyhow!("PTY session '{}' already exists", session_id));
293        }
294
295        let mut command_parts = command.clone();
296        let program = command_parts.remove(0);
297        let args = command_parts;
298
299        let pty_system = native_pty_system();
300        let pair = pty_system
301            .openpty(size)
302            .context("failed to allocate PTY pair")?;
303
304        let mut builder = CommandBuilder::new(program.clone());
305        for arg in &args {
306            builder.arg(arg);
307        }
308        builder.cwd(&working_dir);
309        builder.env("TERM", "xterm-256color");
310        builder.env("COLUMNS", size.cols.to_string());
311        builder.env("LINES", size.rows.to_string());
312
313        let child = pair
314            .slave
315            .spawn_command(builder)
316            .context("failed to spawn PTY session command")?;
317        drop(pair.slave);
318
319        let master = pair.master;
320        let mut reader = master
321            .try_clone_reader()
322            .context("failed to clone PTY reader")?;
323        let writer = master.take_writer().context("failed to take PTY writer")?;
324
325        let parser = Arc::new(Mutex::new(Parser::new(size.rows, size.cols, 0)));
326        let parser_clone = Arc::clone(&parser);
327        let session_name = session_id.clone();
328        let reader_thread = thread::Builder::new()
329            .name(format!("vtcode-pty-reader-{session_name}"))
330            .spawn(move || {
331                let mut buffer = [0u8; 4096];
332                loop {
333                    match reader.read(&mut buffer) {
334                        Ok(0) => {
335                            debug!("PTY session '{}' reader reached EOF", session_name);
336                            break;
337                        }
338                        Ok(bytes_read) => {
339                            if let Ok(mut parser) = parser_clone.lock() {
340                                parser.process(&buffer[..bytes_read]);
341                            }
342                        }
343                        Err(error) => {
344                            warn!("PTY session '{}' reader error: {}", session_name, error);
345                            break;
346                        }
347                    }
348                }
349            })
350            .context("failed to spawn PTY reader thread")?;
351
352        let metadata = VTCodePtySession {
353            id: session_id.clone(),
354            command: program,
355            args,
356            working_dir: Some(self.format_working_dir(&working_dir)),
357            rows: size.rows,
358            cols: size.cols,
359            screen_contents: None,
360        };
361
362        sessions.insert(
363            session_id.clone(),
364            PtySessionHandle {
365                master,
366                child: Mutex::new(child),
367                writer: Mutex::new(Some(writer)),
368                parser,
369                reader_thread: Mutex::new(Some(reader_thread)),
370                metadata: metadata.clone(),
371            },
372        );
373
374        Ok(metadata)
375    }
376
377    pub fn list_sessions(&self) -> Vec<VTCodePtySession> {
378        let sessions = self
379            .inner
380            .sessions
381            .lock()
382            .expect("PTY session mutex poisoned");
383        sessions
384            .values()
385            .map(PtySessionHandle::snapshot_metadata)
386            .collect()
387    }
388
389    pub fn close_session(&self, session_id: &str) -> Result<VTCodePtySession> {
390        let handle = {
391            let mut sessions = self
392                .inner
393                .sessions
394                .lock()
395                .expect("PTY session mutex poisoned");
396            sessions
397                .remove(session_id)
398                .ok_or_else(|| anyhow!("PTY session '{}' not found", session_id))?
399        };
400
401        if let Ok(mut writer_guard) = handle.writer.lock() {
402            if let Some(mut writer) = writer_guard.take() {
403                let _ = writer.write_all(b"exit\n");
404                let _ = writer.flush();
405            }
406        }
407
408        let mut child = handle.child.lock().expect("PTY child mutex poisoned");
409        if child
410            .try_wait()
411            .context("failed to poll PTY session status")?
412            .is_none()
413        {
414            child.kill().context("failed to terminate PTY session")?;
415            let _ = child.wait();
416        }
417
418        if let Ok(mut thread_guard) = handle.reader_thread.lock() {
419            if let Some(reader_thread) = thread_guard.take() {
420                if let Err(panic) = reader_thread.join() {
421                    warn!(
422                        "PTY session '{}' reader thread panicked: {:?}",
423                        session_id, panic
424                    );
425                }
426            }
427        }
428
429        Ok(handle.snapshot_metadata())
430    }
431
432    fn format_working_dir(&self, path: &Path) -> String {
433        match path.strip_prefix(&self.workspace_root) {
434            Ok(relative) if relative.as_os_str().is_empty() => ".".to_string(),
435            Ok(relative) => relative.to_string_lossy().replace("\\", "/"),
436            Err(_) => path.to_string_lossy().to_string(),
437        }
438    }
439}
440
/// Converts a timeout to whole milliseconds, saturating at `u64::MAX` rather
/// than truncating when the duration is too long to fit.
#[cfg(unix)]
fn clamp_timeout(duration: Duration) -> u64 {
    let millis = duration.as_millis();
    if millis > u128::from(u64::MAX) {
        u64::MAX
    } else {
        millis as u64
    }
}
445
446#[cfg(unix)]
447fn exit_status_code(status: portable_pty::ExitStatus) -> i32 {
448    if status.signal().is_some() {
449        -1
450    } else {
451        status.exit_code() as i32
452    }
453}
454
/// Lexically normalizes `path` without touching the filesystem: `.` components
/// are dropped and each `..` cancels the component before it.
///
/// A `..` directly under an absolute root (or a bare drive prefix) is
/// discarded, matching how the OS treats `/..`. For a relative path, a `..`
/// with nothing left to cancel is kept, so `a/../../b` becomes `../b` rather
/// than `b`. Collapsing it would make a path that climbs out of a relative base
/// look as if it were still inside that base.
///
/// Symlinks are not resolved, so the result can differ from `canonicalize`.
fn normalize_path(path: &Path) -> PathBuf {
    let mut normalized = PathBuf::new();
    for component in path.components() {
        match component {
            Component::ParentDir => match normalized.components().next_back() {
                Some(Component::Normal(_)) => {
                    normalized.pop();
                }
                // Already at the root: `..` cannot climb any higher.
                Some(Component::RootDir) | Some(Component::Prefix(_)) => {}
                // Empty, or a relative path already climbing upward: keep it.
                _ => normalized.push(".."),
            },
            Component::CurDir => {}
            Component::Prefix(prefix) => normalized.push(prefix.as_os_str()),
            Component::RootDir => normalized.push(component.as_os_str()),
            Component::Normal(part) => normalized.push(part),
        }
    }
    normalized
}