1use anyhow::{Context, Result};
6use nix::sys::signal::{self, Signal};
7use nix::unistd::Pid;
8use portable_pty::{native_pty_system, CommandBuilder, PtySize};
9use std::io::{Read, Write};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::Arc;
12use std::thread::{self, JoinHandle};
13use std::time::{Duration, Instant};
14
15use crate::config::ExfilDetectionSettings;
16use crate::ipc::client::IpcClient;
17use crate::ipc::protocol::WrapState;
18use crate::wrap::analyzer::Analyzer;
19use crate::wrap::exfil_detector::ExfilDetector;
20
21pub struct PtyRunnerConfig {
23 pub command: String,
25 pub args: Vec<String>,
27 pub id: String,
29 pub rows: u16,
31 pub cols: u16,
33 pub exfil_detection: ExfilDetectionSettings,
35}
36
37impl Default for PtyRunnerConfig {
38 fn default() -> Self {
39 Self {
40 command: String::new(),
41 args: Vec::new(),
42 id: uuid::Uuid::new_v4().to_string(),
43 rows: 24,
44 cols: 80,
45 exfil_detection: ExfilDetectionSettings::default(),
46 }
47 }
48}
49
50pub struct PtyRunner {
52 config: PtyRunnerConfig,
53}
54
55impl PtyRunner {
56 pub fn new(config: PtyRunnerConfig) -> Self {
58 Self { config }
59 }
60
61 pub fn run(self) -> Result<i32> {
63 let (rows, cols) = get_terminal_size().unwrap_or((self.config.rows, self.config.cols));
65
66 let pty_system = native_pty_system();
68 let pair = pty_system
69 .openpty(PtySize {
70 rows,
71 cols,
72 pixel_width: 0,
73 pixel_height: 0,
74 })
75 .context("Failed to open PTY")?;
76
77 let mut cmd = CommandBuilder::new(&self.config.command);
79 cmd.args(&self.config.args);
80
81 if let Ok(cwd) = std::env::current_dir() {
83 cmd.cwd(cwd);
84 }
85
86 let mut child = pair
88 .slave
89 .spawn_command(cmd)
90 .context("Failed to spawn command")?;
91 let child_pid = child.process_id().unwrap_or(0);
92
93 tracing::debug!("Spawned {} with PID {}", self.config.command, child_pid);
94
95 let analyzer = Arc::new(parking_lot::Mutex::new(Analyzer::new(child_pid)));
97
98 let exfil_detector = Arc::new(ExfilDetector::new(&self.config.exfil_detection, child_pid));
100
101 let running = Arc::new(AtomicBool::new(true));
103
104 let mut master_reader = pair
106 .master
107 .try_clone_reader()
108 .context("Failed to clone PTY reader")?;
109 let master_writer = pair
110 .master
111 .take_writer()
112 .context("Failed to take PTY writer")?;
113
114 let master_writer_shared: Arc<parking_lot::Mutex<Box<dyn Write + Send>>> =
116 Arc::new(parking_lot::Mutex::new(master_writer));
117
118 let team_name = analyzer.lock().team_name().cloned();
121 let team_member_name = analyzer.lock().team_member_name().cloned();
122 let is_team_lead = analyzer.lock().is_team_lead();
123
124 let ipc_client = IpcClient::start(
126 self.config.id.clone(),
127 child_pid,
128 team_name,
129 team_member_name,
130 is_team_lead,
131 running.clone(),
132 master_writer_shared.clone(),
133 analyzer.clone(),
134 );
135
136 let analyzer_out = analyzer.clone();
138 let exfil_detector_out = exfil_detector.clone();
139 let running_out = running.clone();
140 let output_thread = thread::spawn(move || {
141 let mut stdout = std::io::stdout();
142 let mut buf = [0u8; 4096];
143
144 while running_out.load(Ordering::Relaxed) {
145 match master_reader.read(&mut buf) {
146 Ok(0) => break, Ok(n) => {
148 if stdout.write_all(&buf[..n]).is_err() {
150 break;
151 }
152 let _ = stdout.flush();
153
154 if let Ok(s) = std::str::from_utf8(&buf[..n]) {
156 analyzer_out.lock().process_output(s);
157
158 exfil_detector_out.check_output(s);
160 }
161 }
162 Err(e) => {
163 if e.kind() != std::io::ErrorKind::WouldBlock {
164 tracing::debug!("PTY read error: {}", e);
165 break;
166 }
167 }
168 }
169 }
170 });
171
172 let analyzer_in = analyzer.clone();
174 let running_in = running.clone();
175 let writer_for_input = master_writer_shared;
176 let input_thread = thread::spawn(move || {
177 let stdin = std::io::stdin();
178 let mut stdin = stdin.lock();
179 let mut buf = [0u8; 1024];
180
181 while running_in.load(Ordering::Relaxed) {
182 match stdin.read(&mut buf) {
183 Ok(0) => break, Ok(n) => {
185 {
187 let mut writer = writer_for_input.lock();
188 if writer.write_all(&buf[..n]).is_err() {
189 break;
190 }
191 let _ = writer.flush();
192 }
193
194 if let Ok(s) = std::str::from_utf8(&buf[..n]) {
196 analyzer_in.lock().process_input(s);
197 }
198 }
199 Err(e) => {
200 if e.kind() != std::io::ErrorKind::WouldBlock {
201 tracing::debug!("stdin read error: {}", e);
202 break;
203 }
204 }
205 }
206 }
207 });
208
209 let analyzer_state = analyzer.clone();
211 let running_state = running.clone();
212 let state_thread = thread::spawn(move || {
213 let mut last_state: Option<WrapState> = None;
214
215 while running_state.load(Ordering::Relaxed) {
216 thread::sleep(Duration::from_millis(100));
217
218 let state = analyzer_state.lock().get_state();
219
220 let should_send = match &last_state {
222 None => true,
223 Some(prev) => !states_equal(prev, &state),
224 };
225
226 if should_send {
227 ipc_client.send_state(state.clone());
228 last_state = Some(state);
229 }
230 }
231 });
232
233 let running_resize = running.clone();
237 let pty_master = pair.master;
238 let resize_thread = thread::spawn(move || {
239 let mut last_size: Option<(u16, u16)> = get_terminal_size();
240
241 while running_resize.load(Ordering::Relaxed) {
242 thread::sleep(Duration::from_millis(100));
243
244 let current_size = get_terminal_size();
245 if current_size != last_size {
246 if let Some((rows, cols)) = current_size {
247 let _ = pty_master.resize(PtySize {
248 rows,
249 cols,
250 pixel_width: 0,
251 pixel_height: 0,
252 });
253 }
254 last_size = current_size;
255 }
256 }
257 });
258
259 let exit_status = child.wait().context("Failed to wait for child")?;
261
262 running.store(false, Ordering::Relaxed);
264
265 join_thread_with_timeout(output_thread, Duration::from_secs(1));
267 join_thread_with_timeout(input_thread, Duration::from_secs(1));
268 join_thread_with_timeout(state_thread, Duration::from_secs(1));
269 join_thread_with_timeout(resize_thread, Duration::from_secs(1));
270
271 Ok(exit_status.exit_code() as i32)
273 }
274}
275
276fn states_equal(a: &WrapState, b: &WrapState) -> bool {
278 a.status == b.status
279 && a.approval_type == b.approval_type
280 && a.details == b.details
281 && a.choices == b.choices
282 && a.multi_select == b.multi_select
283 && a.cursor_position == b.cursor_position
284 && a.pid == b.pid
285 && a.pane_id == b.pane_id
286}
287
288fn join_thread_with_timeout<T>(handle: JoinHandle<T>, timeout: Duration) {
290 let start = Instant::now();
291 loop {
292 if handle.is_finished() {
293 let _ = handle.join();
294 return;
295 }
296 if start.elapsed() >= timeout {
297 tracing::debug!("Thread join timed out, abandoning thread");
298 return;
300 }
301 thread::sleep(Duration::from_millis(10));
302 }
303}
304
305fn get_terminal_size() -> Option<(u16, u16)> {
307 use nix::libc;
308
309 let fd = libc::STDOUT_FILENO;
311 let mut size: libc::winsize = unsafe { std::mem::zeroed() };
312
313 let result = unsafe { libc::ioctl(fd, libc::TIOCGWINSZ, &mut size) };
314
315 if result == 0 && size.ws_row > 0 && size.ws_col > 0 {
316 Some((size.ws_row, size.ws_col))
317 } else {
318 None
319 }
320}
321
322pub fn forward_signal_to_child(child_pid: u32, sig: Signal) -> Result<()> {
324 if child_pid > 0 {
325 signal::kill(Pid::from_raw(child_pid as i32), sig).context("Failed to forward signal")?;
326 }
327 Ok(())
328}
329
/// Splits a command line into the program name and its arguments.
///
/// Words are separated by Unicode whitespace. A small subset of POSIX shell
/// quoting is honoured so that arguments containing spaces survive:
///
/// * `'single quotes'` keep everything literally up to the closing quote;
/// * `"double quotes"` keep everything literally, except that `\"` and `\\`
///   yield `"` and `\`;
/// * outside quotes, a backslash makes the next character literal (a trailing
///   backslash is kept as-is).
///
/// No other shell processing (variables, globs, `~`, operators) is performed.
/// An unterminated quote extends to the end of the input. Empty quotes (`""`)
/// produce an empty argument. Input without quotes or backslashes is split
/// exactly like `str::split_whitespace`.
///
/// Returns `(String::new(), Vec::new())` when `cmd_str` contains no words.
pub fn parse_command(cmd_str: &str) -> (String, Vec<String>) {
    let mut words = split_command_words(cmd_str).into_iter();
    match words.next() {
        Some(command) => (command, words.collect()),
        None => (String::new(), Vec::new()),
    }
}

/// Tokenizes `input` into words using the quoting rules described on [`parse_command`].
fn split_command_words(input: &str) -> Vec<String> {
    #[derive(Clone, Copy, PartialEq)]
    enum Quote {
        Unquoted,
        Single,
        Double,
    }

    let mut words = Vec::new();
    let mut current = String::new();
    // Tracks whether a word has started, so an empty quoted word ("") is kept.
    let mut in_word = false;
    let mut quote = Quote::Unquoted;
    let mut chars = input.chars();

    while let Some(c) = chars.next() {
        match quote {
            Quote::Single => {
                if c == '\'' {
                    quote = Quote::Unquoted;
                } else {
                    current.push(c);
                }
            }
            Quote::Double => match c {
                '"' => quote = Quote::Unquoted,
                '\\' => match chars.next() {
                    Some('"') => current.push('"'),
                    Some('\\') => current.push('\\'),
                    Some(other) => {
                        current.push('\\');
                        current.push(other);
                    }
                    None => current.push('\\'),
                },
                _ => current.push(c),
            },
            Quote::Unquoted => match c {
                '\'' => {
                    quote = Quote::Single;
                    in_word = true;
                }
                '"' => {
                    quote = Quote::Double;
                    in_word = true;
                }
                '\\' => {
                    in_word = true;
                    current.push(chars.next().unwrap_or('\\'));
                }
                c if c.is_whitespace() => {
                    if in_word {
                        words.push(std::mem::take(&mut current));
                        in_word = false;
                    }
                }
                _ => {
                    current.push(c);
                    in_word = true;
                }
            },
        }
    }

    if in_word {
        words.push(current);
    }
    words
}
345
346pub fn get_pane_id() -> String {
348 if let Ok(pane) = std::env::var("TMUX_PANE") {
350 return pane.trim_start_matches('%').to_string();
353 }
354
355 uuid::Uuid::new_v4().to_string()
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_command_simple() {
        let parsed = parse_command("claude");
        assert_eq!(parsed, (String::from("claude"), Vec::<String>::new()));
    }

    #[test]
    fn test_parse_command_with_args() {
        let (program, arguments) = parse_command("claude --debug --config test.toml");
        assert_eq!(program, "claude");
        assert_eq!(arguments, ["--debug", "--config", "test.toml"]);
    }

    #[test]
    fn test_parse_command_empty() {
        let (program, arguments) = parse_command("");
        assert_eq!(program, "");
        assert_eq!(arguments.len(), 0);
    }

    #[test]
    fn test_get_pane_id_fallback() {
        // Whether it comes from tmux or a generated UUID, the id must not be empty.
        assert!(!get_pane_id().is_empty());
    }
}