1use std::fmt;
9use std::io;
10use std::sync::Arc;
11use std::sync::Mutex as StdMutex;
12use std::sync::atomic::{AtomicBool, Ordering};
13
14use tokio::sync::{broadcast, mpsc, oneshot};
15use tokio::task::{AbortHandle, JoinHandle};
16
/// Something that can forcibly stop a spawned child process.
///
/// Implementors must be `Send + Sync` so a boxed terminator can be stored in
/// a handle that is shared across threads.
pub trait ChildTerminator: Send + Sync {
    /// Asks the underlying child to terminate.
    ///
    /// # Errors
    ///
    /// Returns the `io::Error` reported by the implementation when the
    /// termination request fails.
    fn kill(&mut self) -> io::Result<()>;
}
24
/// Type-erased owners of PTY-related resources.
///
/// Nothing in this file reads the fields; they are held only so that the
/// boxed values are dropped together with this struct.
pub struct PtyHandles {
    /// Optional type-erased value; presumably the slave end of the PTY pair
    /// (TODO confirm against the spawning code).
    pub _slave: Option<Box<dyn Send>>,
    /// Type-erased value; presumably the master end of the PTY pair
    /// (TODO confirm against the spawning code).
    pub _master: Box<dyn Send>,
}

impl fmt::Debug for PtyHandles {
    /// Prints only the type name, since the boxed contents are opaque.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("PtyHandles")
    }
}
41
42pub struct ProcessHandle {
50 writer_tx: mpsc::Sender<Vec<u8>>,
51 output_tx: broadcast::Sender<Vec<u8>>,
52 killer: StdMutex<Option<Box<dyn ChildTerminator>>>,
53 reader_handle: StdMutex<Option<JoinHandle<()>>>,
54 reader_abort_handles: StdMutex<Vec<AbortHandle>>,
55 writer_handle: StdMutex<Option<JoinHandle<()>>>,
56 wait_handle: StdMutex<Option<JoinHandle<()>>>,
57 exit_status: Arc<AtomicBool>,
58 exit_code: Arc<StdMutex<Option<i32>>>,
59 _pty_handles: StdMutex<Option<PtyHandles>>,
61}
62
63impl fmt::Debug for ProcessHandle {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 f.debug_struct("ProcessHandle")
66 .field("has_exited", &self.has_exited())
67 .field("exit_code", &self.exit_code())
68 .finish()
69 }
70}
71
72impl ProcessHandle {
73 #[allow(clippy::too_many_arguments)]
75 pub fn new(
76 writer_tx: mpsc::Sender<Vec<u8>>,
77 output_tx: broadcast::Sender<Vec<u8>>,
78 initial_output_rx: broadcast::Receiver<Vec<u8>>,
79 killer: Box<dyn ChildTerminator>,
80 reader_handle: JoinHandle<()>,
81 reader_abort_handles: Vec<AbortHandle>,
82 writer_handle: JoinHandle<()>,
83 wait_handle: JoinHandle<()>,
84 exit_status: Arc<AtomicBool>,
85 exit_code: Arc<StdMutex<Option<i32>>>,
86 pty_handles: Option<PtyHandles>,
87 ) -> (Self, broadcast::Receiver<Vec<u8>>) {
88 (
89 Self {
90 writer_tx,
91 output_tx,
92 killer: StdMutex::new(Some(killer)),
93 reader_handle: StdMutex::new(Some(reader_handle)),
94 reader_abort_handles: StdMutex::new(reader_abort_handles),
95 writer_handle: StdMutex::new(Some(writer_handle)),
96 wait_handle: StdMutex::new(Some(wait_handle)),
97 exit_status,
98 exit_code,
99 _pty_handles: StdMutex::new(pty_handles),
100 },
101 initial_output_rx,
102 )
103 }
104
105 pub fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
113 self.writer_tx.clone()
114 }
115
116 pub fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> {
121 self.output_tx.subscribe()
122 }
123
124 pub fn has_exited(&self) -> bool {
126 self.exit_status.load(Ordering::SeqCst)
127 }
128
129 pub fn exit_code(&self) -> Option<i32> {
131 self.exit_code.lock().ok().and_then(|guard| *guard)
132 }
133
134 pub fn terminate(&self) {
138 self.terminate_internal();
139 }
140
141 fn terminate_internal(&self) {
143 if let Ok(mut killer_opt) = self.killer.lock()
145 && let Some(mut killer) = killer_opt.take()
146 {
147 let _ = killer.kill();
148 }
149
150 self.abort_tasks();
151 }
152
153 fn abort_tasks(&self) {
155 if let Ok(mut h) = self.reader_handle.lock()
157 && let Some(handle) = h.take()
158 {
159 handle.abort();
160 }
161
162 if let Ok(mut handles) = self.reader_abort_handles.lock() {
164 for handle in handles.drain(..) {
165 handle.abort();
166 }
167 }
168
169 if let Ok(mut h) = self.writer_handle.lock()
171 && let Some(handle) = h.take()
172 {
173 handle.abort();
174 }
175
176 if let Ok(mut h) = self.wait_handle.lock()
178 && let Some(handle) = h.take()
179 {
180 handle.abort();
181 }
182 }
183
184 pub fn is_running(&self) -> bool {
186 !self.has_exited() && !self.is_writer_closed()
187 }
188
189 pub async fn write(
193 &self,
194 bytes: impl Into<Vec<u8>>,
195 ) -> Result<(), mpsc::error::SendError<Vec<u8>>> {
196 self.writer_tx.send(bytes.into()).await
197 }
198
199 pub fn is_writer_closed(&self) -> bool {
201 self.writer_tx.is_closed()
202 }
203}
204
205impl Drop for ProcessHandle {
206 fn drop(&mut self) {
207 self.terminate_internal();
208 }
209}
210
211#[derive(Debug)]
215pub struct SpawnedProcess {
216 pub session: ProcessHandle,
218 pub output_rx: broadcast::Receiver<Vec<u8>>,
220 pub exit_rx: oneshot::Receiver<i32>,
222}
223
224impl SpawnedProcess {
225 pub async fn wait_with_output(self, timeout_ms: u64) -> (Vec<u8>, i32) {
229 collect_output_until_exit(self.output_rx, self.exit_rx, timeout_ms).await
230 }
231}
232
233pub async fn collect_output_until_exit(
237 mut output_rx: broadcast::Receiver<Vec<u8>>,
238 exit_rx: oneshot::Receiver<i32>,
239 timeout_ms: u64,
240) -> (Vec<u8>, i32) {
241 let mut collected = Vec::new();
242 let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(timeout_ms);
243 tokio::pin!(exit_rx);
244
245 loop {
246 tokio::select! {
247 res = output_rx.recv() => {
248 if let Ok(chunk) = res {
249 collected.extend_from_slice(&chunk);
250 }
251 }
252 res = &mut exit_rx => {
253 let code = res.unwrap_or(-1);
254 let quiet = tokio::time::Duration::from_millis(50);
256 let max_deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(500);
257
258 while tokio::time::Instant::now() < max_deadline {
259 match tokio::time::timeout(quiet, output_rx.recv()).await {
260 Ok(Ok(chunk)) => collected.extend_from_slice(&chunk),
261 Ok(Err(broadcast::error::RecvError::Lagged(_))) => continue,
262 Ok(Err(broadcast::error::RecvError::Closed)) => break,
263 Err(_) => break, }
265 }
266 return (collected, code);
267 }
268 _ = tokio::time::sleep_until(deadline) => {
269 return (collected, -1);
270 }
271 }
272 }
273}
274
275pub type ExecCommandSession = ProcessHandle;
277
278pub type SpawnedPty = SpawnedProcess;
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Terminator whose `kill` always succeeds and does nothing.
    struct NoopTerminator;

    impl ChildTerminator for NoopTerminator {
        fn kill(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// Builds a `ProcessHandle` backed by no-op tasks and a no-op terminator.
    ///
    /// Must be called from within a Tokio runtime because it spawns tasks.
    fn noop_handle(exit_status: Arc<AtomicBool>) -> ProcessHandle {
        let exit_code = Arc::new(StdMutex::new(None));

        // The receiving halves are not needed by these tests.
        let (writer_tx, _) = mpsc::channel(1);
        let (output_tx, initial_rx) = broadcast::channel(1);

        let (handle, _) = ProcessHandle::new(
            writer_tx,
            output_tx,
            initial_rx,
            Box::new(NoopTerminator),
            tokio::spawn(async {}),
            vec![],
            tokio::spawn(async {}),
            tokio::spawn(async {}),
            exit_status,
            exit_code,
            None,
        );
        handle
    }

    #[tokio::test]
    async fn test_process_handle_debug() {
        let handle = noop_handle(Arc::new(AtomicBool::new(false)));

        let debug_str = format!("{handle:?}");
        assert!(debug_str.contains("ProcessHandle"));
    }

    #[tokio::test]
    async fn test_has_exited() {
        let exit_status = Arc::new(AtomicBool::new(false));
        let handle = noop_handle(Arc::clone(&exit_status));

        assert!(!handle.has_exited());
        exit_status.store(true, Ordering::SeqCst);
        assert!(handle.has_exited());
    }
}