Skip to main content

winx_code_agent/state/
pty.rs

//! Real PTY implementation using portable-pty
//!
//! This module provides a true pseudo-terminal interface for interactive
//! shell sessions, enabling proper handling of:
//! - ANSI escape sequences and colors
//! - Interactive programs (sudo, vim, less, etc.)
//! - Terminal resize events
//! - Job control signals (Ctrl+C, Ctrl+Z, etc.)
9
10use anyhow::{anyhow, Context, Result};
11use portable_pty::{native_pty_system, Child, CommandBuilder, MasterPty, PtySize};
12use std::collections::hash_map::DefaultHasher;
13use std::collections::VecDeque;
14use std::hash::{Hash, Hasher};
15use std::io::{Read, Write};
16use std::path::Path;
17use std::sync::mpsc::{self, TryRecvError};
18use std::sync::Arc;
19use std::thread;
20use std::time::{Duration, Instant};
21use tokio::sync::Mutex;
22use tracing::{debug, info, warn};
23
/// Default terminal width, in columns.
pub const DEFAULT_COLS: u16 = 200;
/// Default terminal height, in rows.
pub const DEFAULT_ROWS: u16 = 50;

/// Maximum output buffer size (in bytes) before truncation, to prevent
/// unbounded memory use.
const MAX_OUTPUT_SIZE: usize = 1_000_000;

/// How many fully-formed lines to keep in the per-shell ringbuffer. Callers can
/// ask for at most this many lines of historical context via
/// `StatusCheck.scrollback_lines`.
pub const RING_BUFFER_LINES: usize = 2_000;

/// Leading marker of the WCGW-style prompt (`◉ <cwd>──➤ `) used for command
/// completion detection.
const WCGW_PROMPT_PATTERN: &str = "◉";
/// Trailing marker of the WCGW-style prompt. Output must contain both this and
/// `WCGW_PROMPT_PATTERN` for a command to be treated as finished.
const WCGW_PROMPT_END: &str = "──➤";
39
40/// Real PTY-based interactive shell
41///
42/// Uses portable-pty for true pseudo-terminal functionality,
43/// enabling proper handling of interactive programs like sudo, vim, etc.
44pub struct PtyShell {
45    /// The PTY master handle for resize operations
46    master: Box<dyn MasterPty + Send>,
47    /// Child process running the shell
48    child: Box<dyn Child + Send + Sync>,
49    /// Writer for PTY input (taken from master)
50    writer: Box<dyn Write + Send>,
51    /// Channel receiver for output from reader thread
52    output_rx: mpsc::Receiver<String>,
53    /// Current terminal size
54    size: PtySize,
55    /// Last command executed
56    pub last_command: String,
57    /// Accumulated output buffer
58    pub output_buffer: String,
59    /// Whether a command is currently running
60    pub command_running: bool,
61    /// Maximum output size before truncation
62    max_output_size: usize,
63    /// Flag for output truncation
64    pub output_truncated: bool,
65    /// Rolling buffer of fully-emitted lines for opt-in scrollback. The newest
66    /// line is at the back; capped at `RING_BUFFER_LINES`.
67    pub line_ring: VecDeque<String>,
68    /// Carries the unterminated tail across reads so partial lines aren't
69    /// double-counted when more bytes arrive.
70    line_ring_partial: String,
71    /// Hash of the last rendered output we shipped to the caller. Used by the
72    /// delta path in `status_check` to elide repeats when the screen is idle.
73    pub last_returned_hash: Option<u64>,
74}
75
76impl std::fmt::Debug for PtyShell {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        f.debug_struct("PtyShell")
79            .field("size", &format!("{}x{}", self.size.cols, self.size.rows))
80            .field("last_command", &self.last_command)
81            .field("command_running", &self.command_running)
82            .field("output_truncated", &self.output_truncated)
83            .field("output_buffer_len", &self.output_buffer.len())
84            .finish_non_exhaustive()
85    }
86}
87
88impl PtyShell {
89    /// Create a new PTY shell session
90    ///
91    /// # Arguments
92    /// * `initial_dir` - Starting directory for the shell
93    /// * `restricted_mode` - Whether to use bash restricted mode (-r)
94    ///
95    /// # Returns
96    /// A new `PtyShell` instance with an active bash session
97    pub fn new(initial_dir: &Path, restricted_mode: bool) -> Result<Self> {
98        info!(
99            "Creating new PTY shell (restricted: {}) in {}",
100            restricted_mode,
101            initial_dir.display()
102        );
103
104        // Initialize the native PTY system
105        let pty_system = native_pty_system();
106
107        // Configure terminal size
108        let size =
109            PtySize { rows: DEFAULT_ROWS, cols: DEFAULT_COLS, pixel_width: 0, pixel_height: 0 };
110
111        // Open the PTY pair (master + slave)
112        let pair = pty_system.openpty(size).context("Failed to open PTY pair")?;
113
114        // Build the command
115        let mut cmd = CommandBuilder::new("bash");
116        if restricted_mode {
117            cmd.arg("-r");
118        }
119
120        // Set up environment for proper terminal behavior
121        cmd.env("TERM", "xterm-256color");
122        cmd.env("COLORTERM", "truecolor");
123        cmd.env("PAGER", "cat");
124        cmd.env("GIT_PAGER", "cat");
125        cmd.env("COLUMNS", DEFAULT_COLS.to_string());
126        cmd.env("ROWS", DEFAULT_ROWS.to_string());
127        // WCGW-style prompt for command completion detection
128        // Note: removed \r\e[2K which was erasing the prompt before it could be detected
129        cmd.env("PROMPT_COMMAND", r#"printf "◉ %s──➤ " "$PWD""#);
130        cmd.cwd(initial_dir);
131
132        // Spawn bash in the PTY slave
133        let child = pair.slave.spawn_command(cmd).context("Failed to spawn bash in PTY")?;
134
135        // Get reader and writer from master
136        let mut reader = pair.master.try_clone_reader().context("Failed to clone PTY reader")?;
137        let writer = pair.master.take_writer().context("Failed to take PTY writer")?;
138
139        // Create channel for output from reader thread
140        let (output_tx, output_rx) = mpsc::channel::<String>();
141
142        // Spawn a background thread to read from the PTY
143        // This prevents blocking the main thread
144        thread::spawn(move || {
145            let mut buf = [0u8; 4096];
146            loop {
147                match reader.read(&mut buf) {
148                    Ok(0) => {
149                        // EOF - PTY closed
150                        break;
151                    }
152                    Ok(n) => {
153                        let chunk = String::from_utf8_lossy(&buf[..n]).to_string();
154                        if output_tx.send(chunk).is_err() {
155                            // Receiver dropped, exit thread
156                            break;
157                        }
158                    }
159                    Err(e) => {
160                        debug!("PTY reader thread error: {}", e);
161                        break;
162                    }
163                }
164            }
165            debug!("PTY reader thread exiting");
166        });
167
168        // Create the shell instance
169        let mut shell = Self {
170            master: pair.master,
171            child,
172            writer,
173            output_rx,
174            size,
175            last_command: String::new(),
176            output_buffer: String::new(),
177            command_running: false,
178            max_output_size: MAX_OUTPUT_SIZE,
179            output_truncated: false,
180            line_ring: VecDeque::with_capacity(RING_BUFFER_LINES),
181            line_ring_partial: String::new(),
182            last_returned_hash: None,
183        };
184
185        // Initialize the shell with WCGW-style prompt
186        shell.initialize_prompt()?;
187
188        debug!("PTY shell created successfully");
189        Ok(shell)
190    }
191
192    /// Initialize the shell prompt for WCGW compatibility
193    fn initialize_prompt(&mut self) -> Result<()> {
194        // Set up the dynamic prompt - matches WCGW Python PROMPT_STATEMENT
195        // Note: removed \r\e[2K which was erasing the prompt before it could be detected
196        let prompt_statement =
197            r#"export GIT_PAGER=cat PAGER=cat PROMPT_COMMAND='printf "◉ %s──➤ " "$PWD"'"#;
198
199        self.write_command(prompt_statement)?;
200
201        // Wait for prompt to be ready
202        std::thread::sleep(Duration::from_millis(100));
203        let _ = self.drain_output();
204
205        Ok(())
206    }
207
208    /// Write a command to the PTY
209    fn write_command(&mut self, command: &str) -> Result<()> {
210        // Commands in PTY need \r\n for proper terminal behavior
211        let cmd_with_newline = format!("{command}\n");
212        self.writer.write_all(cmd_with_newline.as_bytes()).context("Failed to write to PTY")?;
213        self.writer.flush().context("Failed to flush PTY")?;
214        Ok(())
215    }
216
217    /// Drain any pending output from the PTY channel
218    fn drain_output(&mut self) -> String {
219        let mut output = String::new();
220        let deadline = Instant::now() + Duration::from_millis(200);
221
222        // Drain all available output from the channel
223        while Instant::now() < deadline {
224            match self.output_rx.try_recv() {
225                Ok(chunk) => {
226                    output.push_str(&chunk);
227
228                    // Prevent runaway reads
229                    if output.len() > self.max_output_size {
230                        self.output_truncated = true;
231                        break;
232                    }
233                }
234                Err(TryRecvError::Empty) => {
235                    // No more data, wait briefly for more
236                    thread::sleep(Duration::from_millis(10));
237                }
238                Err(TryRecvError::Disconnected) => {
239                    // Reader thread died
240                    break;
241                }
242            }
243        }
244
245        output
246    }
247
248    /// Drain any pending output and, if a previous command still seems to be
249    /// running, send a Ctrl-C to flush it. Mirrors wcgw's `clear_to_run` so a
250    /// new command never inherits stale prompt fragments or a half-typed line.
251    ///
252    /// Returns `true` if the shell looks idle (prompt seen), `false` if it
253    /// still wouldn't yield after the Ctrl-C — caller may want to reset.
254    pub fn clear_to_run(&mut self, max_wait_secs: f32) -> Result<bool> {
255        // Drain whatever is in the channel without blocking. Use the existing
256        // read_output to also catch the prompt fingerprint.
257        let (_, complete) = self.read_output(max_wait_secs.min(0.5))?;
258        if complete {
259            return Ok(true);
260        }
261
262        // Something is still running — interrupt it.
263        debug!("clear_to_run: prompt not seen, sending Ctrl+C");
264        self.send_interrupt()?;
265
266        // Re-drain after the interrupt so the next command starts on a clean prompt.
267        let (_, drained) = self.read_output(max_wait_secs)?;
268        Ok(drained)
269    }
270
271    /// Send a command to the shell and start reading output
272    pub fn send_command(&mut self, command: &str) -> Result<()> {
273        debug!("PTY sending command: {}", command);
274
275        // Clear previous state
276        self.output_buffer.clear();
277        self.output_truncated = false;
278        self.last_command = command.to_string();
279        self.command_running = true;
280        // A new command means the next status_check should return whatever
281        // shows up — drop the dedup hash so we don't elide the first response.
282        self.last_returned_hash = None;
283
284        // Write the command
285        self.write_command(command)?;
286
287        Ok(())
288    }
289
290    /// Push freshly-arrived bytes through the line-oriented ringbuffer so
291    /// callers can request bounded scrollback later.
292    fn ingest_into_ring(&mut self, chunk: &str) {
293        let combined = if self.line_ring_partial.is_empty() {
294            chunk.to_string()
295        } else {
296            let mut s = std::mem::take(&mut self.line_ring_partial);
297            s.push_str(chunk);
298            s
299        };
300
301        let mut last_nl_end: Option<usize> = None;
302        for (idx, ch) in combined.char_indices() {
303            if ch == '\n' {
304                let end = idx + ch.len_utf8();
305                let start = last_nl_end.unwrap_or(0);
306                let line = combined[start..idx].trim_end_matches('\r').to_string();
307                if self.line_ring.len() == RING_BUFFER_LINES {
308                    self.line_ring.pop_front();
309                }
310                self.line_ring.push_back(line);
311                last_nl_end = Some(end);
312            }
313        }
314
315        if let Some(end) = last_nl_end {
316            self.line_ring_partial = combined[end..].to_string();
317        } else {
318            self.line_ring_partial = combined;
319        }
320    }
321
322    /// Return up to `lines` recent lines from the ringbuffer, oldest first.
323    /// Includes any in-flight partial line.
324    pub fn collect_scrollback(&self, lines: usize) -> String {
325        if lines == 0 {
326            return String::new();
327        }
328        let start = self.line_ring.len().saturating_sub(lines);
329        let mut out = String::new();
330        for line in self.line_ring.iter().skip(start) {
331            out.push_str(line);
332            out.push('\n');
333        }
334        if !self.line_ring_partial.is_empty() {
335            out.push_str(&self.line_ring_partial);
336        }
337        out
338    }
339
340    /// Hash arbitrary rendered output into a u64 dedup key.
341    pub fn fingerprint(text: &str) -> u64 {
342        let mut hasher = DefaultHasher::new();
343        text.hash(&mut hasher);
344        hasher.finish()
345    }
346
347    /// Read output from the PTY with timeout
348    ///
349    /// Returns (output, `is_complete`) tuple where `is_complete` indicates
350    /// whether the command has finished (prompt detected)
351    pub fn read_output(&mut self, timeout_secs: f32) -> Result<(String, bool)> {
352        let timeout = Duration::from_secs_f32(timeout_secs.clamp(0.1, 60.0));
353        let start = Instant::now();
354        let mut complete = false;
355        let mut no_data_count = 0;
356        let mut prompt_detected_at: Option<Instant> = None;
357
358        while start.elapsed() < timeout {
359            match self.output_rx.try_recv() {
360                Ok(chunk) => {
361                    self.output_buffer.push_str(&chunk);
362                    self.ingest_into_ring(&chunk);
363                    no_data_count = 0;
364
365                    // Check for WCGW prompt indicating command completion
366                    if prompt_detected_at.is_none()
367                        && (Self::check_prompt_complete(&chunk)
368                            || Self::check_prompt_complete(&self.output_buffer))
369                    {
370                        prompt_detected_at = Some(Instant::now());
371                        debug!("Prompt detected, draining remaining output...");
372                    }
373
374                    // Truncate if too large
375                    if self.output_buffer.len() > self.max_output_size {
376                        self.output_truncated = true;
377                        let truncate_msg = "\n(...output truncated...)\n";
378                        let keep_size = self.max_output_size / 2;
379                        self.output_buffer = format!(
380                            "{}{}",
381                            truncate_msg,
382                            &self.output_buffer[self.output_buffer.len() - keep_size..]
383                        );
384                    }
385                }
386                Err(TryRecvError::Empty) => {
387                    // No data available, wait briefly
388                    thread::sleep(Duration::from_millis(10));
389                    no_data_count += 1;
390
391                    // If prompt was detected, check if we've drained long enough
392                    if let Some(detected_time) = prompt_detected_at {
393                        // Wait 100ms after prompt detection to capture any trailing output
394                        if detected_time.elapsed() > Duration::from_millis(100) {
395                            complete = true;
396                            debug!("Command completed - prompt detected and drained");
397                            break;
398                        }
399                    } else if no_data_count > 10 && Self::check_prompt_complete(&self.output_buffer)
400                    {
401                        // Prompt detected during empty reads
402                        prompt_detected_at = Some(Instant::now());
403                        debug!("Prompt detected after wait, draining...");
404                    }
405                }
406                Err(TryRecvError::Disconnected) => {
407                    // Reader thread died - PTY closed
408                    warn!("PTY reader disconnected");
409                    complete = true;
410                    break;
411                }
412            }
413        }
414
415        if complete || prompt_detected_at.is_some() {
416            self.command_running = false;
417            complete = true;
418        }
419
420        Ok((self.output_buffer.clone(), complete))
421    }
422
423    /// Check if the output contains the WCGW-style prompt
424    fn check_prompt_complete(text: &str) -> bool {
425        // Look for the WCGW prompt pattern: ◉ /path──➤
426        text.contains(WCGW_PROMPT_PATTERN) && text.contains(WCGW_PROMPT_END)
427    }
428
429    /// Send Ctrl+C (interrupt) to the PTY
430    pub fn send_interrupt(&mut self) -> Result<()> {
431        debug!("PTY sending Ctrl+C");
432        self.writer
433            .write_all(&[0x03]) // ASCII ETX (Ctrl+C)
434            .context("Failed to send Ctrl+C")?;
435        self.writer.flush()?;
436        Ok(())
437    }
438
439    /// Send Ctrl+D (EOF) to the PTY
440    pub fn send_eof(&mut self) -> Result<()> {
441        debug!("PTY sending Ctrl+D");
442        self.writer
443            .write_all(&[0x04]) // ASCII EOT (Ctrl+D)
444            .context("Failed to send Ctrl+D")?;
445        self.writer.flush()?;
446        Ok(())
447    }
448
449    /// Send Ctrl+Z (suspend) to the PTY
450    pub fn send_suspend(&mut self) -> Result<()> {
451        debug!("PTY sending Ctrl+Z");
452        self.writer
453            .write_all(&[0x1A]) // ASCII SUB (Ctrl+Z)
454            .context("Failed to send Ctrl+Z")?;
455        self.writer.flush()?;
456        Ok(())
457    }
458
459    /// Send text directly to the PTY (for interactive input)
460    pub fn send_text(&mut self, text: &str) -> Result<()> {
461        debug!("PTY sending text: {:?}", text);
462        self.send_bytes(text.as_bytes()).context("Failed to send text")?;
463        Ok(())
464    }
465
466    /// Send raw bytes directly to the PTY.
467    pub fn send_bytes(&mut self, bytes: &[u8]) -> Result<()> {
468        self.writer.write_all(bytes).context("Failed to send bytes")?;
469        self.writer.flush()?;
470        Ok(())
471    }
472
473    /// Send a special key sequence
474    pub fn send_special_key(&mut self, key: &str) -> Result<()> {
475        let bytes: &[u8] = match key {
476            "Enter" => b"\r",
477            "Tab" => b"\t",
478            "Backspace" => b"\x7F",
479            "Escape" => b"\x1B",
480            "Up" | "KeyUp" => b"\x1B[A",
481            "Down" | "KeyDown" => b"\x1B[B",
482            "Right" | "KeyRight" => b"\x1B[C",
483            "Left" | "KeyLeft" => b"\x1B[D",
484            "Home" => b"\x1B[H",
485            "End" => b"\x1B[F",
486            "PageUp" => b"\x1B[5~",
487            "PageDown" => b"\x1B[6~",
488            "Delete" => b"\x1B[3~",
489            "Insert" => b"\x1B[2~",
490            "CtrlC" | "Ctrl-C" => b"\x03",
491            "CtrlD" | "Ctrl-D" => b"\x04",
492            "CtrlZ" | "Ctrl-Z" => b"\x1A",
493            "CtrlL" | "Ctrl-L" => b"\x0C",
494            _ => return Err(anyhow!("Unknown special key: {key}")),
495        };
496
497        debug!("PTY sending special key: {} ({:?})", key, bytes);
498        self.send_bytes(bytes)?;
499        Ok(())
500    }
501
502    /// Resize the terminal
503    pub fn resize(&mut self, cols: u16, rows: u16) -> Result<()> {
504        debug!("PTY resizing to {}x{}", cols, rows);
505
506        let new_size = PtySize { rows, cols, pixel_width: 0, pixel_height: 0 };
507
508        self.master.resize(new_size).context("Failed to resize PTY")?;
509
510        self.size = new_size;
511        Ok(())
512    }
513
514    /// Get current terminal size
515    pub fn get_size(&self) -> (u16, u16) {
516        (self.size.cols, self.size.rows)
517    }
518
519    /// Check if the shell is still alive
520    pub fn is_alive(&mut self) -> bool {
521        self.child.try_wait().is_ok_and(|status| status.is_none())
522    }
523}
524
525/// Thread-safe wrapper for `PtyShell`
526pub type SharedPtyShell = Arc<Mutex<Option<PtyShell>>>;
527
528/// Create a new shared PTY shell
529pub fn create_shared_pty(initial_dir: &Path, restricted_mode: bool) -> Result<SharedPtyShell> {
530    let shell = PtyShell::new(initial_dir, restricted_mode)?;
531    Ok(Arc::new(Mutex::new(Some(shell))))
532}
533
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Start an unrestricted shell rooted in `dir`.
    fn fresh_shell(dir: &TempDir) -> Result<PtyShell> {
        PtyShell::new(dir.path(), false)
    }

    /// Send `command` and return whatever output arrives within two seconds.
    fn run(shell: &mut PtyShell, command: &str) -> Result<String> {
        shell.send_command(command)?;
        let (output, _complete) = shell.read_output(2.0)?;
        Ok(output)
    }

    #[test]
    fn test_pty_shell_creation() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let result = fresh_shell(&temp_dir);
        assert!(result.is_ok(), "Failed to create PTY shell: {:?}", result.err());
        Ok(())
    }

    #[test]
    fn test_pty_shell_echo() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let mut shell = fresh_shell(&temp_dir)?;

        let output = run(&mut shell, "echo 'hello pty'")?;
        assert!(output.contains("hello pty"), "Output should contain 'hello pty': {output}");
        Ok(())
    }

    #[test]
    fn test_pty_shell_pwd() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let mut shell = fresh_shell(&temp_dir)?;

        // The trailing echo marker proves the command actually executed.
        let output = run(&mut shell, "pwd && echo 'pwd_done'")?;
        assert!(output.contains("pwd_done"), "Output should contain 'pwd_done': {output}");
        Ok(())
    }

    #[test]
    fn test_pty_resize() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let mut shell = fresh_shell(&temp_dir)?;

        assert!(shell.resize(120, 40).is_ok());
        assert_eq!(shell.get_size(), (120, 40));
        Ok(())
    }
}