1use std::fmt;
9use std::io;
10use std::sync::Arc;
11use std::sync::Mutex as StdMutex;
12use std::sync::atomic::{AtomicBool, Ordering};
13
14use bytes::Bytes;
15use tokio::sync::{broadcast, mpsc, oneshot};
16use tokio::task::{AbortHandle, JoinHandle};
17
/// Something whose `kill` can be invoked to terminate a child.
///
/// `ProcessHandle` stores one boxed implementor and calls [`kill`](Self::kill)
/// on it at most once, when the handle is terminated or dropped.
pub trait ChildTerminator: Send + Sync {
    /// Requests termination, reporting any I/O failure to the caller.
    fn kill(&mut self) -> io::Result<()>;
}
25
/// Opaque owners that are only held, never inspected, by this module.
///
/// NOTE(review): the field names suggest these are the two ends of a PTY kept
/// alive for the lifetime of the process — confirm against the spawning code.
/// Fields are dropped in declaration order (`_slave` first, then `_master`).
pub struct PtyHandles {
    /// Optional boxed value; presumably the PTY slave side — confirm.
    pub _slave: Option<Box<dyn Send>>,
    /// Boxed value that is always present; presumably the PTY master side — confirm.
    pub _master: Box<dyn Send>,
}
36
37impl fmt::Debug for PtyHandles {
38 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39 f.debug_struct("PtyHandles").finish()
40 }
41}
42
/// Handle to a spawned process together with the channels and background
/// tasks that serve it.
///
/// Dropping the handle runs the same termination path as
/// `ProcessHandle::terminate`.
pub struct ProcessHandle {
    /// Sender used by `write`; clones are handed out by `writer_sender`.
    /// Presumably the writer task forwards these bytes to the child — confirm.
    writer_tx: mpsc::Sender<Vec<u8>>,
    /// Broadcast sender that `output_receiver` subscribes to.
    output_tx: broadcast::Sender<Bytes>,
    /// Terminator for the child; taken out (leaving `None`) the first time
    /// the handle is terminated, so `kill` is called at most once.
    killer: StdMutex<Option<Box<dyn ChildTerminator>>>,
    /// Reader task, aborted on termination.
    reader_handle: StdMutex<Option<JoinHandle<()>>>,
    /// Additional abort handles, all aborted and drained on termination.
    reader_abort_handles: StdMutex<Vec<AbortHandle>>,
    /// Writer task, aborted on termination.
    writer_handle: StdMutex<Option<JoinHandle<()>>>,
    /// Wait task, aborted on termination.
    wait_handle: StdMutex<Option<JoinHandle<()>>>,
    /// Shared flag read by `has_exited`; it is set outside this type.
    exit_status: Arc<AtomicBool>,
    /// Shared slot read by `exit_code`; it is filled outside this type.
    exit_code: Arc<StdMutex<Option<i32>>>,
    /// Stored but never read here; presumably held only to keep the PTY
    /// open while the handle lives — confirm.
    _pty_handles: StdMutex<Option<PtyHandles>>,
}
63
64impl fmt::Debug for ProcessHandle {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 f.debug_struct("ProcessHandle")
67 .field("has_exited", &self.has_exited())
68 .field("exit_code", &self.exit_code())
69 .finish()
70 }
71}
72
73impl ProcessHandle {
74 #[allow(clippy::too_many_arguments)]
76 pub fn new(
77 writer_tx: mpsc::Sender<Vec<u8>>,
78 output_tx: broadcast::Sender<Bytes>,
79 initial_output_rx: broadcast::Receiver<Bytes>,
80 killer: Box<dyn ChildTerminator>,
81 reader_handle: JoinHandle<()>,
82 reader_abort_handles: Vec<AbortHandle>,
83 writer_handle: JoinHandle<()>,
84 wait_handle: JoinHandle<()>,
85 exit_status: Arc<AtomicBool>,
86 exit_code: Arc<StdMutex<Option<i32>>>,
87 pty_handles: Option<PtyHandles>,
88 ) -> (Self, broadcast::Receiver<Bytes>) {
89 (
90 Self {
91 writer_tx,
92 output_tx,
93 killer: StdMutex::new(Some(killer)),
94 reader_handle: StdMutex::new(Some(reader_handle)),
95 reader_abort_handles: StdMutex::new(reader_abort_handles),
96 writer_handle: StdMutex::new(Some(writer_handle)),
97 wait_handle: StdMutex::new(Some(wait_handle)),
98 exit_status,
99 exit_code,
100 _pty_handles: StdMutex::new(pty_handles),
101 },
102 initial_output_rx,
103 )
104 }
105
106 pub fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
114 self.writer_tx.clone()
115 }
116
117 pub fn output_receiver(&self) -> broadcast::Receiver<Bytes> {
122 self.output_tx.subscribe()
123 }
124
125 pub fn has_exited(&self) -> bool {
127 self.exit_status.load(Ordering::SeqCst)
128 }
129
130 pub fn exit_code(&self) -> Option<i32> {
132 self.exit_code.lock().ok().and_then(|guard| *guard)
133 }
134
135 pub fn terminate(&self) {
139 self.terminate_internal();
140 }
141
142 fn terminate_internal(&self) {
144 if let Ok(mut killer_opt) = self.killer.lock()
146 && let Some(mut killer) = killer_opt.take()
147 {
148 let _ = killer.kill();
149 }
150
151 self.abort_tasks();
152 }
153
154 fn abort_tasks(&self) {
156 if let Ok(mut h) = self.reader_handle.lock()
158 && let Some(handle) = h.take()
159 {
160 handle.abort();
161 }
162
163 if let Ok(mut handles) = self.reader_abort_handles.lock() {
165 for handle in handles.drain(..) {
166 handle.abort();
167 }
168 }
169
170 if let Ok(mut h) = self.writer_handle.lock()
172 && let Some(handle) = h.take()
173 {
174 handle.abort();
175 }
176
177 if let Ok(mut h) = self.wait_handle.lock()
179 && let Some(handle) = h.take()
180 {
181 handle.abort();
182 }
183 }
184
185 pub fn is_running(&self) -> bool {
187 !self.has_exited() && !self.is_writer_closed()
188 }
189
190 pub async fn write(
194 &self,
195 bytes: impl Into<Vec<u8>>,
196 ) -> Result<(), mpsc::error::SendError<Vec<u8>>> {
197 self.writer_tx.send(bytes.into()).await
198 }
199
200 pub fn is_writer_closed(&self) -> bool {
202 self.writer_tx.is_closed()
203 }
204}
205
206impl Drop for ProcessHandle {
207 fn drop(&mut self) {
208 self.terminate_internal();
209 }
210}
211
/// Bundle describing a spawned process: its handle plus the receivers needed
/// to observe its output and exit.
#[derive(Debug)]
pub struct SpawnedProcess {
    /// Handle controlling the process; dropping it terminates the process.
    pub session: ProcessHandle,
    /// Receiver of output chunks, consumed by `wait_with_output`.
    pub output_rx: broadcast::Receiver<Bytes>,
    /// Receiver of the exit code, consumed by `wait_with_output`.
    pub exit_rx: oneshot::Receiver<i32>,
}
224
225impl SpawnedProcess {
226 pub async fn wait_with_output(self, timeout_ms: u64) -> (Vec<u8>, i32) {
230 collect_output_until_exit(self.output_rx, self.exit_rx, timeout_ms).await
231 }
232}
233
234pub async fn collect_output_until_exit(
238 mut output_rx: broadcast::Receiver<Bytes>,
239 exit_rx: oneshot::Receiver<i32>,
240 timeout_ms: u64,
241) -> (Vec<u8>, i32) {
242 let mut collected = Vec::new();
243 let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(timeout_ms);
244 tokio::pin!(exit_rx);
245
246 loop {
247 tokio::select! {
248 res = output_rx.recv() => {
249 if let Ok(chunk) = res {
250 collected.extend_from_slice(&chunk);
251 }
252 }
253 res = &mut exit_rx => {
254 let code = res.unwrap_or(-1);
255 let quiet = tokio::time::Duration::from_millis(50);
257 let max_deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(500);
258
259 while tokio::time::Instant::now() < max_deadline {
260 match tokio::time::timeout(quiet, output_rx.recv()).await {
261 Ok(Ok(chunk)) => collected.extend_from_slice(&chunk),
262 Ok(Err(broadcast::error::RecvError::Lagged(_))) => continue,
263 Ok(Err(broadcast::error::RecvError::Closed)) => break,
264 Err(_) => break, }
266 }
267 return (collected, code);
268 }
269 _ = tokio::time::sleep_until(deadline) => {
270 return (collected, -1);
271 }
272 }
273 }
274}
275
/// Alias for [`ProcessHandle`].
/// NOTE(review): presumably kept so callers using the older name still compile — confirm.
pub type ExecCommandSession = ProcessHandle;

/// Alias for [`SpawnedProcess`].
/// NOTE(review): presumably kept so callers using the older name still compile — confirm.
pub type SpawnedPty = SpawnedProcess;
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Terminator whose `kill` always succeeds without doing anything.
    struct NoopTerminator;

    impl ChildTerminator for NoopTerminator {
        fn kill(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// Builds a handle around no-op tasks whose exit flag is `exit_status`.
    ///
    /// Must be called from inside a Tokio runtime because it spawns tasks.
    /// The writer receiver and the initial output receiver are dropped
    /// right away.
    fn idle_handle(exit_status: Arc<AtomicBool>) -> ProcessHandle {
        let (writer_tx, _) = mpsc::channel(1);
        let (output_tx, initial_rx) = broadcast::channel(1);
        let exit_code = Arc::new(StdMutex::new(None));

        let (handle, _) = ProcessHandle::new(
            writer_tx,
            output_tx,
            initial_rx,
            Box::new(NoopTerminator),
            tokio::spawn(async {}),
            vec![],
            tokio::spawn(async {}),
            tokio::spawn(async {}),
            exit_status,
            exit_code,
            None,
        );
        handle
    }

    #[tokio::test]
    async fn test_process_handle_debug() {
        let handle = idle_handle(Arc::new(AtomicBool::new(false)));

        let rendered = format!("{handle:?}");
        assert!(rendered.contains("ProcessHandle"));
    }

    #[tokio::test]
    async fn test_has_exited() {
        let flag = Arc::new(AtomicBool::new(false));
        let handle = idle_handle(Arc::clone(&flag));

        assert!(!handle.has_exited());
        // Flipping the shared flag must be visible through the handle.
        flag.store(true, Ordering::SeqCst);
        assert!(handle.has_exited());
    }
}