1use std::fmt;
9use std::io;
10use std::sync::Arc;
11use std::sync::Mutex as StdMutex;
12use std::sync::atomic::{AtomicBool, Ordering};
13
14use bytes::Bytes;
15use tokio::sync::{broadcast, mpsc, oneshot};
16use tokio::task::{AbortHandle, JoinHandle};
17
/// Capability to forcibly stop a spawned child process.
///
/// The `Send + Sync` supertraits allow a boxed terminator to be stored inside a handle that
/// is shared across threads.
pub trait ChildTerminator: Send + Sync {
    /// Requests that the child be killed.
    ///
    /// # Errors
    ///
    /// Implementation-defined: returns whatever I/O error the implementation hits while
    /// stopping the child.
    fn kill(&mut self) -> std::io::Result<()>;
}
25
/// Type-erased PTY endpoints that are only held, never read, so they stay alive.
///
/// Both endpoints are boxed as `dyn Send`, so nothing can be done with them except keep
/// them and eventually drop them.
pub struct PtyHandles {
    /// Slave end of the PTY, judging by the name; optional, presumably because some spawners
    /// release it early (confirm at the spawn site). It is declared first, so it is dropped
    /// before `_master` (Rust drops fields in declaration order).
    pub _slave: Option<Box<dyn Send>>,
    /// Master end of the PTY, judging by the name.
    pub _master: Box<dyn Send>,
}
36
37impl fmt::Debug for PtyHandles {
38 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39 f.debug_struct("PtyHandles").finish()
40 }
41}
42
/// Handle to a spawned child process and the background tasks that service it.
///
/// Dropping the handle kills the child through its [`ChildTerminator`] and aborts the
/// reader, writer and wait tasks (see the `Drop` impl).
pub struct ProcessHandle {
    /// Sender whose bytes go to the writer task; presumably forwarded to the child's input —
    /// confirm against the writer task.
    writer_tx: mpsc::Sender<Vec<u8>>,
    /// Broadcast sender for output chunks; `output_receiver()` subscribes to it.
    output_tx: broadcast::Sender<Bytes>,
    /// Terminator for the child; taken (left as `None`) the first time termination runs.
    killer: StdMutex<Option<Box<dyn ChildTerminator>>>,
    /// Reader task; `is_output_drained()` checks whether it has finished. Taken on abort.
    reader_handle: StdMutex<Option<JoinHandle<()>>>,
    /// Additional abort handles for reader-related tasks; drained and aborted on termination.
    reader_abort_handles: StdMutex<Vec<AbortHandle>>,
    /// Writer task; taken and aborted on termination.
    writer_handle: StdMutex<Option<JoinHandle<()>>>,
    /// Wait task; taken and aborted on termination.
    wait_handle: StdMutex<Option<JoinHandle<()>>>,
    /// Shared "has exited" flag read by `has_exited()`; written elsewhere (not in this file's
    /// visible code).
    exit_status: Arc<AtomicBool>,
    /// Shared exit code read by `exit_code()`; `None` until written elsewhere.
    exit_code: Arc<StdMutex<Option<i32>>>,
    /// PTY endpoints, held only to keep them alive. Declared last, so they are dropped after
    /// every other field.
    _pty_handles: StdMutex<Option<PtyHandles>>,
}
63
64impl fmt::Debug for ProcessHandle {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 f.debug_struct("ProcessHandle")
67 .field("has_exited", &self.has_exited())
68 .field("exit_code", &self.exit_code())
69 .finish()
70 }
71}
72
73impl ProcessHandle {
74 #[allow(clippy::too_many_arguments)]
76 pub fn new(
77 writer_tx: mpsc::Sender<Vec<u8>>,
78 output_tx: broadcast::Sender<Bytes>,
79 initial_output_rx: broadcast::Receiver<Bytes>,
80 killer: Box<dyn ChildTerminator>,
81 reader_handle: JoinHandle<()>,
82 reader_abort_handles: Vec<AbortHandle>,
83 writer_handle: JoinHandle<()>,
84 wait_handle: JoinHandle<()>,
85 exit_status: Arc<AtomicBool>,
86 exit_code: Arc<StdMutex<Option<i32>>>,
87 pty_handles: Option<PtyHandles>,
88 ) -> (Self, broadcast::Receiver<Bytes>) {
89 (
90 Self {
91 writer_tx,
92 output_tx,
93 killer: StdMutex::new(Some(killer)),
94 reader_handle: StdMutex::new(Some(reader_handle)),
95 reader_abort_handles: StdMutex::new(reader_abort_handles),
96 writer_handle: StdMutex::new(Some(writer_handle)),
97 wait_handle: StdMutex::new(Some(wait_handle)),
98 exit_status,
99 exit_code,
100 _pty_handles: StdMutex::new(pty_handles),
101 },
102 initial_output_rx,
103 )
104 }
105
106 pub fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
114 self.writer_tx.clone()
115 }
116
117 pub fn output_receiver(&self) -> broadcast::Receiver<Bytes> {
122 self.output_tx.subscribe()
123 }
124
125 pub fn has_exited(&self) -> bool {
127 self.exit_status.load(Ordering::SeqCst)
128 }
129
130 pub fn exit_code(&self) -> Option<i32> {
132 self.exit_code.lock().ok().and_then(|guard| *guard)
133 }
134
135 pub fn is_output_drained(&self) -> bool {
137 self.reader_handle
138 .lock()
139 .ok()
140 .and_then(|guard| guard.as_ref().map(JoinHandle::is_finished))
141 .unwrap_or(true)
142 }
143
144 pub fn terminate(&self) {
148 self.terminate_internal();
149 }
150
151 fn terminate_internal(&self) {
153 if let Ok(mut killer_opt) = self.killer.lock()
155 && let Some(mut killer) = killer_opt.take()
156 {
157 let _ = killer.kill();
158 }
159
160 self.abort_tasks();
161 }
162
163 fn abort_tasks(&self) {
165 if let Ok(mut h) = self.reader_handle.lock()
167 && let Some(handle) = h.take()
168 {
169 handle.abort();
170 }
171
172 if let Ok(mut handles) = self.reader_abort_handles.lock() {
174 for handle in handles.drain(..) {
175 handle.abort();
176 }
177 }
178
179 if let Ok(mut h) = self.writer_handle.lock()
181 && let Some(handle) = h.take()
182 {
183 handle.abort();
184 }
185
186 if let Ok(mut h) = self.wait_handle.lock()
188 && let Some(handle) = h.take()
189 {
190 handle.abort();
191 }
192 }
193
194 pub fn is_running(&self) -> bool {
196 !self.has_exited() && !self.is_writer_closed()
197 }
198
199 pub async fn write(
203 &self,
204 bytes: impl Into<Vec<u8>>,
205 ) -> Result<(), mpsc::error::SendError<Vec<u8>>> {
206 self.writer_tx.send(bytes.into()).await
207 }
208
209 pub fn is_writer_closed(&self) -> bool {
211 self.writer_tx.is_closed()
212 }
213}
214
215impl Drop for ProcessHandle {
216 fn drop(&mut self) {
217 self.terminate_internal();
218 }
219}
220
#[derive(Debug)]
/// A freshly spawned process: its control handle plus the receivers for output and exit.
///
/// `session` is declared first, so when a `SpawnedProcess` is dropped whole, the handle
/// (whose `Drop` terminates the child) is dropped before the receivers.
pub struct SpawnedProcess {
    /// Control handle for the child process.
    pub session: ProcessHandle,
    /// Receiver for output chunks; presumably the `initial_output_rx` returned by
    /// `ProcessHandle::new` — confirm at the spawn site.
    pub output_rx: broadcast::Receiver<Bytes>,
    /// Resolves with the child's exit code once it is sent.
    pub exit_rx: oneshot::Receiver<i32>,
}
233
234impl SpawnedProcess {
235 pub async fn wait_with_output(self, timeout_ms: u64) -> (Vec<u8>, i32) {
239 collect_output_until_exit(self.output_rx, self.exit_rx, timeout_ms).await
240 }
241}
242
243pub async fn collect_output_until_exit(
247 mut output_rx: broadcast::Receiver<Bytes>,
248 exit_rx: oneshot::Receiver<i32>,
249 timeout_ms: u64,
250) -> (Vec<u8>, i32) {
251 let mut collected = Vec::new();
252 let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(timeout_ms);
253 tokio::pin!(exit_rx);
254
255 loop {
256 tokio::select! {
257 res = output_rx.recv() => {
258 if let Ok(chunk) = res {
259 collected.extend_from_slice(&chunk);
260 }
261 }
262 res = &mut exit_rx => {
263 let code = res.unwrap_or(-1);
264 let quiet = tokio::time::Duration::from_millis(50);
266 let max_deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(500);
267
268 while tokio::time::Instant::now() < max_deadline {
269 match tokio::time::timeout(quiet, output_rx.recv()).await {
270 Ok(Ok(chunk)) => collected.extend_from_slice(&chunk),
271 Ok(Err(broadcast::error::RecvError::Lagged(_))) => continue,
272 Ok(Err(broadcast::error::RecvError::Closed)) => break,
273 Err(_) => break, }
275 }
276 return (collected, code);
277 }
278 _ = tokio::time::sleep_until(deadline) => {
279 return (collected, -1);
280 }
281 }
282 }
283}
284
/// Alias of [`ProcessHandle`]; presumably kept so callers using the older name keep
/// compiling — confirm before removing.
pub type ExecCommandSession = ProcessHandle;

/// Alias of [`SpawnedProcess`]; presumably kept so callers using the older name keep
/// compiling — confirm before removing.
pub type SpawnedPty = SpawnedProcess;
290
#[cfg(test)]
mod tests {
    use super::*;

    /// Terminator that does nothing, so tests never touch a real process.
    struct NoopTerminator;

    impl ChildTerminator for NoopTerminator {
        fn kill(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// Builds a handle whose tasks are empty futures and whose channels have capacity 1.
    ///
    /// Must run inside a Tokio runtime because it spawns tasks.
    fn make_handle(exit_status: Arc<AtomicBool>) -> ProcessHandle {
        let exit_code = Arc::new(StdMutex::new(None));
        let (writer_tx, _) = mpsc::channel(1);
        let (output_tx, initial_rx) = broadcast::channel(1);

        ProcessHandle::new(
            writer_tx,
            output_tx,
            initial_rx,
            Box::new(NoopTerminator),
            tokio::spawn(async {}),
            vec![],
            tokio::spawn(async {}),
            tokio::spawn(async {}),
            exit_status,
            exit_code,
            None,
        )
        .0
    }

    #[tokio::test]
    async fn test_process_handle_debug() {
        let handle = make_handle(Arc::new(AtomicBool::new(false)));

        let rendered = format!("{handle:?}");
        assert!(rendered.contains("ProcessHandle"));
    }

    #[tokio::test]
    async fn test_has_exited() {
        let exit_status = Arc::new(AtomicBool::new(false));
        let handle = make_handle(Arc::clone(&exit_status));

        assert!(!handle.has_exited());
        exit_status.store(true, Ordering::SeqCst);
        assert!(handle.has_exited());
    }
}