Skip to main content

tmai_core/wrap/
runner.rs

//! PTY runner for wrapping AI agents.
//!
//! Creates a PTY and runs the specified command in it, proxying terminal I/O while monitoring the agent's state.
4
5use anyhow::{Context, Result};
6use nix::sys::signal::{self, Signal};
7use nix::unistd::Pid;
8use portable_pty::{native_pty_system, CommandBuilder, PtySize};
9use std::io::{Read, Write};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::Arc;
12use std::thread::{self, JoinHandle};
13use std::time::{Duration, Instant};
14
15use crate::config::ExfilDetectionSettings;
16use crate::ipc::client::IpcClient;
17use crate::ipc::protocol::WrapState;
18use crate::wrap::analyzer::Analyzer;
19use crate::wrap::exfil_detector::ExfilDetector;
20
21/// PTY runner configuration
22pub struct PtyRunnerConfig {
23    /// Command to run
24    pub command: String,
25    /// Arguments
26    pub args: Vec<String>,
27    /// Unique ID for state file (e.g., tmux pane ID or UUID)
28    pub id: String,
29    /// Initial PTY size
30    pub rows: u16,
31    /// Initial PTY columns
32    pub cols: u16,
33    /// External transmission detection settings
34    pub exfil_detection: ExfilDetectionSettings,
35}
36
37impl Default for PtyRunnerConfig {
38    fn default() -> Self {
39        Self {
40            command: String::new(),
41            args: Vec::new(),
42            id: uuid::Uuid::new_v4().to_string(),
43            rows: 24,
44            cols: 80,
45            exfil_detection: ExfilDetectionSettings::default(),
46        }
47    }
48}
49
50/// PTY runner that wraps an AI agent
51pub struct PtyRunner {
52    config: PtyRunnerConfig,
53}
54
55impl PtyRunner {
56    /// Create a new PTY runner
57    pub fn new(config: PtyRunnerConfig) -> Self {
58        Self { config }
59    }
60
61    /// Run the wrapped command
62    pub fn run(self) -> Result<i32> {
63        // Get terminal size from the current terminal
64        let (rows, cols) = get_terminal_size().unwrap_or((self.config.rows, self.config.cols));
65
66        // Create PTY
67        let pty_system = native_pty_system();
68        let pair = pty_system
69            .openpty(PtySize {
70                rows,
71                cols,
72                pixel_width: 0,
73                pixel_height: 0,
74            })
75            .context("Failed to open PTY")?;
76
77        // Build command
78        let mut cmd = CommandBuilder::new(&self.config.command);
79        cmd.args(&self.config.args);
80
81        // Set working directory to current directory
82        if let Ok(cwd) = std::env::current_dir() {
83            cmd.cwd(cwd);
84        }
85
86        // Spawn child process
87        let mut child = pair
88            .slave
89            .spawn_command(cmd)
90            .context("Failed to spawn command")?;
91        let child_pid = child.process_id().unwrap_or(0);
92
93        tracing::debug!("Spawned {} with PID {}", self.config.command, child_pid);
94
95        // Create analyzer
96        let analyzer = Arc::new(parking_lot::Mutex::new(Analyzer::new(child_pid)));
97
98        // Create exfil detector
99        let exfil_detector = Arc::new(ExfilDetector::new(&self.config.exfil_detection, child_pid));
100
101        // Flag for shutdown
102        let running = Arc::new(AtomicBool::new(true));
103
104        // Get PTY master for read/write
105        let mut master_reader = pair
106            .master
107            .try_clone_reader()
108            .context("Failed to clone PTY reader")?;
109        let master_writer = pair
110            .master
111            .take_writer()
112            .context("Failed to take PTY writer")?;
113
114        // Wrap master_writer in Arc<Mutex> for sharing between input_thread and IPC client
115        let master_writer_shared: Arc<parking_lot::Mutex<Box<dyn Write + Send>>> =
116            Arc::new(parking_lot::Mutex::new(master_writer));
117
118        // Extract team info before starting IPC client
119        // (each lock() must be in its own statement to avoid deadlock)
120        let team_name = analyzer.lock().team_name().cloned();
121        let team_member_name = analyzer.lock().team_member_name().cloned();
122        let is_team_lead = analyzer.lock().is_team_lead();
123
124        // Start IPC client for communication with tmai parent
125        let ipc_client = IpcClient::start(
126            self.config.id.clone(),
127            child_pid,
128            team_name,
129            team_member_name,
130            is_team_lead,
131            running.clone(),
132            master_writer_shared.clone(),
133            analyzer.clone(),
134        );
135
136        // Thread: Read from PTY master -> write to stdout
137        let analyzer_out = analyzer.clone();
138        let exfil_detector_out = exfil_detector.clone();
139        let running_out = running.clone();
140        let output_thread = thread::spawn(move || {
141            let mut stdout = std::io::stdout();
142            let mut buf = [0u8; 4096];
143
144            while running_out.load(Ordering::Relaxed) {
145                match master_reader.read(&mut buf) {
146                    Ok(0) => break, // EOF
147                    Ok(n) => {
148                        // Write to stdout
149                        if stdout.write_all(&buf[..n]).is_err() {
150                            break;
151                        }
152                        let _ = stdout.flush();
153
154                        // Process for state detection (convert to string, ignoring invalid UTF-8)
155                        if let Ok(s) = std::str::from_utf8(&buf[..n]) {
156                            analyzer_out.lock().process_output(s);
157
158                            // Check for external transmission commands
159                            exfil_detector_out.check_output(s);
160                        }
161                    }
162                    Err(e) => {
163                        if e.kind() != std::io::ErrorKind::WouldBlock {
164                            tracing::debug!("PTY read error: {}", e);
165                            break;
166                        }
167                    }
168                }
169            }
170        });
171
172        // Thread: Read from stdin -> write to PTY master
173        let analyzer_in = analyzer.clone();
174        let running_in = running.clone();
175        let writer_for_input = master_writer_shared;
176        let input_thread = thread::spawn(move || {
177            let stdin = std::io::stdin();
178            let mut stdin = stdin.lock();
179            let mut buf = [0u8; 1024];
180
181            while running_in.load(Ordering::Relaxed) {
182                match stdin.read(&mut buf) {
183                    Ok(0) => break, // EOF
184                    Ok(n) => {
185                        // Write to PTY via shared writer
186                        {
187                            let mut writer = writer_for_input.lock();
188                            if writer.write_all(&buf[..n]).is_err() {
189                                break;
190                            }
191                            let _ = writer.flush();
192                        }
193
194                        // Process for state detection
195                        if let Ok(s) = std::str::from_utf8(&buf[..n]) {
196                            analyzer_in.lock().process_input(s);
197                        }
198                    }
199                    Err(e) => {
200                        if e.kind() != std::io::ErrorKind::WouldBlock {
201                            tracing::debug!("stdin read error: {}", e);
202                            break;
203                        }
204                    }
205                }
206            }
207        });
208
209        // Thread: Periodic state update via IPC with change detection
210        let analyzer_state = analyzer.clone();
211        let running_state = running.clone();
212        let state_thread = thread::spawn(move || {
213            let mut last_state: Option<WrapState> = None;
214
215            while running_state.load(Ordering::Relaxed) {
216                thread::sleep(Duration::from_millis(100));
217
218                let state = analyzer_state.lock().get_state();
219
220                // Only send if state has changed
221                let should_send = match &last_state {
222                    None => true,
223                    Some(prev) => !states_equal(prev, &state),
224                };
225
226                if should_send {
227                    ipc_client.send_state(state.clone());
228                    last_state = Some(state);
229                }
230            }
231        });
232
233        // Thread: Periodic terminal size check for resize
234        // Instead of using SIGWINCH signal handler (which requires unsafe TLS access),
235        // we poll for terminal size changes. This is simpler and avoids undefined behavior.
236        let running_resize = running.clone();
237        let pty_master = pair.master;
238        let resize_thread = thread::spawn(move || {
239            let mut last_size: Option<(u16, u16)> = get_terminal_size();
240
241            while running_resize.load(Ordering::Relaxed) {
242                thread::sleep(Duration::from_millis(100));
243
244                let current_size = get_terminal_size();
245                if current_size != last_size {
246                    if let Some((rows, cols)) = current_size {
247                        let _ = pty_master.resize(PtySize {
248                            rows,
249                            cols,
250                            pixel_width: 0,
251                            pixel_height: 0,
252                        });
253                    }
254                    last_size = current_size;
255                }
256            }
257        });
258
259        // Wait for child to exit
260        let exit_status = child.wait().context("Failed to wait for child")?;
261
262        // Signal threads to stop (IPC connection will close automatically)
263        running.store(false, Ordering::Relaxed);
264
265        // Wait for threads with timeout to avoid hanging on blocked stdin
266        join_thread_with_timeout(output_thread, Duration::from_secs(1));
267        join_thread_with_timeout(input_thread, Duration::from_secs(1));
268        join_thread_with_timeout(state_thread, Duration::from_secs(1));
269        join_thread_with_timeout(resize_thread, Duration::from_secs(1));
270
271        // Return exit code
272        Ok(exit_status.exit_code() as i32)
273    }
274}
275
276/// Compare two WrapState instances for equality (ignoring timestamps)
277fn states_equal(a: &WrapState, b: &WrapState) -> bool {
278    a.status == b.status
279        && a.approval_type == b.approval_type
280        && a.details == b.details
281        && a.choices == b.choices
282        && a.multi_select == b.multi_select
283        && a.cursor_position == b.cursor_position
284        && a.pid == b.pid
285        && a.pane_id == b.pane_id
286}
287
288/// Join a thread with a timeout, abandoning it if it doesn't finish in time
289fn join_thread_with_timeout<T>(handle: JoinHandle<T>, timeout: Duration) {
290    let start = Instant::now();
291    loop {
292        if handle.is_finished() {
293            let _ = handle.join();
294            return;
295        }
296        if start.elapsed() >= timeout {
297            tracing::debug!("Thread join timed out, abandoning thread");
298            // Thread will be leaked but we can't block forever
299            return;
300        }
301        thread::sleep(Duration::from_millis(10));
302    }
303}
304
305/// Get current terminal size
306fn get_terminal_size() -> Option<(u16, u16)> {
307    use nix::libc;
308
309    // Try to get size from STDOUT
310    let fd = libc::STDOUT_FILENO;
311    let mut size: libc::winsize = unsafe { std::mem::zeroed() };
312
313    let result = unsafe { libc::ioctl(fd, libc::TIOCGWINSZ, &mut size) };
314
315    if result == 0 && size.ws_row > 0 && size.ws_col > 0 {
316        Some((size.ws_row, size.ws_col))
317    } else {
318        None
319    }
320}
321
322/// Forward a signal to the child process
323pub fn forward_signal_to_child(child_pid: u32, sig: Signal) -> Result<()> {
324    if child_pid > 0 {
325        signal::kill(Pid::from_raw(child_pid as i32), sig).context("Failed to forward signal")?;
326    }
327    Ok(())
328}
329
/// Parse command string into command and arguments
///
/// The first whitespace-separated token is the command; the remaining tokens
/// are its arguments. Quoting and shell escaping are NOT understood; for
/// complex commands, pass them as pre-parsed arguments. Blank input yields an
/// empty command and no arguments.
pub fn parse_command(cmd_str: &str) -> (String, Vec<String>) {
    let mut tokens = cmd_str.split_whitespace().map(str::to_owned);
    match tokens.next() {
        Some(command) => (command, tokens.collect()),
        None => (String::new(), Vec::new()),
    }
}
345
346/// Determine the pane ID from environment or generate one
347pub fn get_pane_id() -> String {
348    // Try TMUX_PANE environment variable
349    if let Ok(pane) = std::env::var("TMUX_PANE") {
350        // TMUX_PANE is like "%0", "%1", etc.
351        // We want just the number
352        return pane.trim_start_matches('%').to_string();
353    }
354
355    // Fall back to UUID
356    uuid::Uuid::new_v4().to_string()
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_command_simple() {
        // A lone word is the command with no arguments.
        let parsed = parse_command("claude");
        assert_eq!(parsed.0, "claude");
        assert_eq!(parsed.1.len(), 0);
    }

    #[test]
    fn test_parse_command_with_args() {
        // Every token after the first becomes an argument, in order.
        let (command, arguments) = parse_command("claude --debug --config test.toml");
        let expected_args = ["--debug", "--config", "test.toml"];
        assert_eq!(command, "claude");
        assert_eq!(arguments, expected_args.to_vec());
    }

    #[test]
    fn test_parse_command_empty() {
        // Empty input yields an empty command and no arguments.
        let (command, arguments) = parse_command("");
        assert_eq!(command, "");
        assert!(arguments.is_empty());
    }

    #[test]
    fn test_get_pane_id_fallback() {
        // Outside tmux this is a UUID; inside tmux it is the pane number.
        // Either way the ID must not be empty.
        assert!(!get_pane_id().is_empty());
    }
}