Skip to main content

vtcode_bash_runner/
process.rs

//! Unified process handle types for PTY and pipe backends.
//!
//! This module provides abstractions for interacting with spawned processes
//! regardless of whether they use a PTY or regular pipes.
//!
//! Inspired by codex-rs/utils/pty process handle patterns.
7
use std::fmt;
use std::io;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use std::sync::{MutexGuard, PoisonError};
use std::sync::atomic::{AtomicBool, Ordering};

use bytes::Bytes;
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::task::{AbortHandle, JoinHandle};
17
/// Strategy for forcibly stopping a spawned child process.
///
/// PTY and pipe backends may need to stop their children differently, so each
/// backend plugs in its own implementation. Implementors must be `Send + Sync`.
pub trait ChildTerminator: Send + Sync {
    /// Kills the child process.
    ///
    /// # Errors
    /// Returns an [`io::Error`] if the backend fails to kill the child.
    fn kill(&mut self) -> io::Result<()>;
}
25
/// PTY endpoints that have to stay alive alongside the child process.
///
/// Both handles are stored type-erased (`Box<dyn Send>`) and are not inspected
/// here; they exist only to be held. If the slave handle were closed, the child
/// would receive SIGHUP.
pub struct PtyHandles {
    /// Slave side of the PTY, when the backend kept one (held to avoid SIGHUP).
    pub _slave: Option<Box<dyn Send>>,
    /// Master side of the PTY.
    pub _master: Box<dyn Send>,
}

impl fmt::Debug for PtyHandles {
    /// Writes only the type name: the boxed handles are opaque `dyn Send`
    /// values with no `Debug` implementation to delegate to.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("PtyHandles")
    }
}
42
43/// Handle for driving an interactive or non-interactive process.
44///
45/// This provides a unified interface for both PTY and pipe-based processes:
46/// - Write to stdin via `writer_sender()`
47/// - Read merged stdout/stderr via `output_receiver()`
48/// - Check exit status via `has_exited()` and `exit_code()`
49/// - Clean up via `terminate()`
50pub struct ProcessHandle {
51    writer_tx: mpsc::Sender<Vec<u8>>,
52    output_tx: broadcast::Sender<Bytes>,
53    killer: StdMutex<Option<Box<dyn ChildTerminator>>>,
54    reader_handle: StdMutex<Option<JoinHandle<()>>>,
55    reader_abort_handles: StdMutex<Vec<AbortHandle>>,
56    writer_handle: StdMutex<Option<JoinHandle<()>>>,
57    wait_handle: StdMutex<Option<JoinHandle<()>>>,
58    exit_status: Arc<AtomicBool>,
59    exit_code: Arc<StdMutex<Option<i32>>>,
60    // PTY handles must be preserved to prevent the process from receiving Control+C
61    _pty_handles: StdMutex<Option<PtyHandles>>,
62}
63
64impl fmt::Debug for ProcessHandle {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        f.debug_struct("ProcessHandle")
67            .field("has_exited", &self.has_exited())
68            .field("exit_code", &self.exit_code())
69            .finish()
70    }
71}
72
73impl ProcessHandle {
74    /// Create a new process handle with all required components.
75    #[allow(clippy::too_many_arguments)]
76    pub fn new(
77        writer_tx: mpsc::Sender<Vec<u8>>,
78        output_tx: broadcast::Sender<Bytes>,
79        initial_output_rx: broadcast::Receiver<Bytes>,
80        killer: Box<dyn ChildTerminator>,
81        reader_handle: JoinHandle<()>,
82        reader_abort_handles: Vec<AbortHandle>,
83        writer_handle: JoinHandle<()>,
84        wait_handle: JoinHandle<()>,
85        exit_status: Arc<AtomicBool>,
86        exit_code: Arc<StdMutex<Option<i32>>>,
87        pty_handles: Option<PtyHandles>,
88    ) -> (Self, broadcast::Receiver<Bytes>) {
89        (
90            Self {
91                writer_tx,
92                output_tx,
93                killer: StdMutex::new(Some(killer)),
94                reader_handle: StdMutex::new(Some(reader_handle)),
95                reader_abort_handles: StdMutex::new(reader_abort_handles),
96                writer_handle: StdMutex::new(Some(writer_handle)),
97                wait_handle: StdMutex::new(Some(wait_handle)),
98                exit_status,
99                exit_code,
100                _pty_handles: StdMutex::new(pty_handles),
101            },
102            initial_output_rx,
103        )
104    }
105
106    /// Returns a channel sender for writing raw bytes to the child stdin.
107    ///
108    /// # Example
109    /// ```ignore
110    /// let writer = handle.writer_sender();
111    /// writer.send(b"input\n".to_vec()).await?;
112    /// ```
113    pub fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
114        self.writer_tx.clone()
115    }
116
117    /// Returns a broadcast receiver that yields stdout/stderr chunks.
118    ///
119    /// Multiple receivers can be created; each receives all output from the
120    /// point of subscription.
121    pub fn output_receiver(&self) -> broadcast::Receiver<Bytes> {
122        self.output_tx.subscribe()
123    }
124
125    /// True if the child process has exited.
126    pub fn has_exited(&self) -> bool {
127        self.exit_status.load(Ordering::SeqCst)
128    }
129
130    /// Returns the exit code if the process has exited.
131    pub fn exit_code(&self) -> Option<i32> {
132        self.exit_code.lock().ok().and_then(|guard| *guard)
133    }
134
135    /// True once the stdout/stderr reader task has drained the child streams.
136    pub fn is_output_drained(&self) -> bool {
137        self.reader_handle
138            .lock()
139            .ok()
140            .and_then(|guard| guard.as_ref().map(JoinHandle::is_finished))
141            .unwrap_or(true)
142    }
143
144    /// Attempts to kill the child and abort helper tasks.
145    ///
146    /// This is idempotent and safe to call multiple times.
147    pub fn terminate(&self) {
148        self.terminate_internal();
149    }
150
151    /// Internal termination that aborts all tasks.
152    fn terminate_internal(&self) {
153        // Kill the child process
154        if let Ok(mut killer_opt) = self.killer.lock()
155            && let Some(mut killer) = killer_opt.take()
156        {
157            let _ = killer.kill();
158        }
159
160        self.abort_tasks();
161    }
162
163    /// Abort all background tasks associated with this process.
164    fn abort_tasks(&self) {
165        // Abort reader handle
166        if let Ok(mut h) = self.reader_handle.lock()
167            && let Some(handle) = h.take()
168        {
169            handle.abort();
170        }
171
172        // Abort individual reader abort handles
173        if let Ok(mut handles) = self.reader_abort_handles.lock() {
174            for handle in handles.drain(..) {
175                handle.abort();
176            }
177        }
178
179        // Abort writer handle
180        if let Ok(mut h) = self.writer_handle.lock()
181            && let Some(handle) = h.take()
182        {
183            handle.abort();
184        }
185
186        // Abort wait handle
187        if let Ok(mut h) = self.wait_handle.lock()
188            && let Some(handle) = h.take()
189        {
190            handle.abort();
191        }
192    }
193
194    /// Check if the process is still running.
195    pub fn is_running(&self) -> bool {
196        !self.has_exited() && !self.is_writer_closed()
197    }
198
199    /// Send bytes to the process stdin.
200    ///
201    /// Returns an error if the stdin channel is closed.
202    pub async fn write(
203        &self,
204        bytes: impl Into<Vec<u8>>,
205    ) -> Result<(), mpsc::error::SendError<Vec<u8>>> {
206        self.writer_tx.send(bytes.into()).await
207    }
208
209    /// Check if the writer channel is closed.
210    pub fn is_writer_closed(&self) -> bool {
211        self.writer_tx.is_closed()
212    }
213}
214
215impl Drop for ProcessHandle {
216    fn drop(&mut self) {
217        self.terminate_internal();
218    }
219}
220
221/// Return value from spawn helpers (PTY or pipe).
222///
223/// Bundles the process handle with receivers for output and exit notification.
224#[derive(Debug)]
225pub struct SpawnedProcess {
226    /// Handle for interacting with the process.
227    pub session: ProcessHandle,
228    /// Receiver for stdout/stderr output chunks.
229    pub output_rx: broadcast::Receiver<Bytes>,
230    /// Receiver for exit code (receives once when process exits).
231    pub exit_rx: oneshot::Receiver<i32>,
232}
233
234impl SpawnedProcess {
235    /// Convenience method to wait for the process to exit and collect output.
236    ///
237    /// Returns (collected_output, exit_code).
238    pub async fn wait_with_output(self, timeout_ms: u64) -> (Vec<u8>, i32) {
239        collect_output_until_exit(self.output_rx, self.exit_rx, timeout_ms).await
240    }
241}
242
243/// Collect output from a process until it exits or times out.
244///
245/// This is useful for tests and simple use cases where you want all output.
246pub async fn collect_output_until_exit(
247    mut output_rx: broadcast::Receiver<Bytes>,
248    exit_rx: oneshot::Receiver<i32>,
249    timeout_ms: u64,
250) -> (Vec<u8>, i32) {
251    let mut collected = Vec::new();
252    let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(timeout_ms);
253    tokio::pin!(exit_rx);
254
255    loop {
256        tokio::select! {
257            res = output_rx.recv() => {
258                if let Ok(chunk) = res {
259                    collected.extend_from_slice(&chunk);
260                }
261            }
262            res = &mut exit_rx => {
263                let code = res.unwrap_or(-1);
264                // Drain remaining output briefly after exit
265                let quiet = tokio::time::Duration::from_millis(50);
266                let max_deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(500);
267
268                while tokio::time::Instant::now() < max_deadline {
269                    match tokio::time::timeout(quiet, output_rx.recv()).await {
270                        Ok(Ok(chunk)) => collected.extend_from_slice(&chunk),
271                        Ok(Err(broadcast::error::RecvError::Lagged(_))) => continue,
272                        Ok(Err(broadcast::error::RecvError::Closed)) => break,
273                        Err(_) => break, // Timeout - quiet period reached
274                    }
275                }
276                return (collected, code);
277            }
278            _ = tokio::time::sleep_until(deadline) => {
279                return (collected, -1);
280            }
281        }
282    }
283}
284
285/// Backwards-compatible alias for ProcessHandle.
286pub type ExecCommandSession = ProcessHandle;
287
288/// Backwards-compatible alias for SpawnedProcess.
289pub type SpawnedPty = SpawnedProcess;
290
#[cfg(test)]
mod tests {
    use super::*;

    /// Terminator whose `kill` always succeeds without doing anything.
    struct NoopTerminator;
    impl ChildTerminator for NoopTerminator {
        fn kill(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// Builds a handle around placeholder tasks that shares the given exit state.
    ///
    /// Must run inside a Tokio runtime because it spawns the placeholder tasks.
    fn handle_with_state(
        exit_status: Arc<AtomicBool>,
        exit_code: Arc<StdMutex<Option<i32>>>,
    ) -> ProcessHandle {
        let (writer_tx, _) = mpsc::channel(1);
        let (output_tx, initial_rx) = broadcast::channel(1);

        let (handle, _) = ProcessHandle::new(
            writer_tx,
            output_tx,
            initial_rx,
            Box::new(NoopTerminator),
            tokio::spawn(async {}),
            vec![],
            tokio::spawn(async {}),
            tokio::spawn(async {}),
            exit_status,
            exit_code,
            None,
        );
        handle
    }

    #[tokio::test]
    async fn test_process_handle_debug() {
        // Just verify Debug impl doesn't panic
        let handle = handle_with_state(
            Arc::new(AtomicBool::new(false)),
            Arc::new(StdMutex::new(None)),
        );

        let debug_str = format!("{handle:?}");
        assert!(debug_str.contains("ProcessHandle"));
    }

    #[tokio::test]
    async fn test_has_exited() {
        let exit_status = Arc::new(AtomicBool::new(false));
        let handle = handle_with_state(Arc::clone(&exit_status), Arc::new(StdMutex::new(None)));

        assert!(!handle.has_exited());
        exit_status.store(true, Ordering::SeqCst);
        assert!(handle.has_exited());
    }

    #[tokio::test]
    async fn test_exit_code_reflects_shared_state() {
        let exit_code = Arc::new(StdMutex::new(None));
        let handle = handle_with_state(Arc::new(AtomicBool::new(false)), Arc::clone(&exit_code));

        assert_eq!(handle.exit_code(), None);
        *exit_code.lock().unwrap() = Some(3);
        assert_eq!(handle.exit_code(), Some(3));
    }

    #[tokio::test]
    async fn test_terminate_is_idempotent() {
        let handle = handle_with_state(
            Arc::new(AtomicBool::new(false)),
            Arc::new(StdMutex::new(None)),
        );

        handle.terminate();
        handle.terminate();
        // The reader handle has been taken, so output counts as drained.
        assert!(handle.is_output_drained());
    }

    #[tokio::test]
    async fn test_collect_output_until_exit_returns_output_and_code() {
        let (output_tx, output_rx) = broadcast::channel(8);
        let (exit_tx, exit_rx) = oneshot::channel();
        output_tx.send(Bytes::from_static(b"hello")).unwrap();
        exit_tx.send(7).unwrap();

        let (output, code) = collect_output_until_exit(output_rx, exit_rx, 1_000).await;
        assert_eq!(code, 7);
        assert_eq!(output, b"hello");
    }

    #[tokio::test]
    async fn test_collect_output_until_exit_with_closed_output() {
        let (output_tx, output_rx) = broadcast::channel(8);
        let (exit_tx, exit_rx) = oneshot::channel();
        output_tx.send(Bytes::from_static(b"done")).unwrap();
        drop(output_tx);
        exit_tx.send(0).unwrap();

        let (output, code) = collect_output_until_exit(output_rx, exit_rx, 1_000).await;
        assert_eq!(code, 0);
        assert_eq!(output, b"done");
    }

    #[tokio::test]
    async fn test_collect_output_until_exit_dropped_exit_sender() {
        let (_output_tx, output_rx) = broadcast::channel::<Bytes>(1);
        let (exit_tx, exit_rx) = oneshot::channel::<i32>();
        drop(exit_tx);

        let (output, code) = collect_output_until_exit(output_rx, exit_rx, 1_000).await;
        assert_eq!(code, -1);
        assert!(output.is_empty());
    }
}