1
2use core::pin::Pin;
3use core::task::{Context, Poll};
4use russh::{Channel, ChannelId, ChannelMsg};
5use std::io;
6use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, SimplexStream, WriteHalf};
7use tokio::task::JoinHandle;
8use crate::{ExitStatus, ExitStatusImp};
9
10#[derive(Debug)]
13pub struct ChildStdin {
14 pub(crate) inner: WriteHalf<SimplexStream>,
15}
16
17impl AsyncWrite for ChildStdin {
18 fn poll_write(
19 self: Pin<&mut Self>,
20 cx: &mut Context<'_>,
21 buf: &[u8],
22 ) -> Poll<Result<usize, io::Error>> {
23 let this = self.get_mut();
24 Pin::new(&mut this.inner).poll_write(cx, buf)
25 }
26
27 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
28 let this = self.get_mut();
29 Pin::new(&mut this.inner).poll_flush(cx)
30 }
31
32 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
33 let this = self.get_mut();
34 Pin::new(&mut this.inner).poll_shutdown(cx)
35 }
36}
37
38#[derive(Debug)]
41pub struct ChildStdout {
42 pub(crate) inner: ReadHalf<SimplexStream>,
43}
44
45impl AsyncRead for ChildStdout {
46 fn poll_read(
47 self: Pin<&mut Self>,
48 cx: &mut Context,
49 buf: &mut ReadBuf,
50 ) -> Poll<Result<(), io::Error>> {
51 let this = self.get_mut();
52 Pin::new(&mut this.inner).poll_read(cx, buf)
53 }
54}
55
56#[derive(Debug)]
59pub struct ChildStderr {
60 pub(crate) inner: ReadHalf<SimplexStream>,
61}
62
63impl AsyncRead for ChildStderr {
64 fn poll_read(
65 self: Pin<&mut Self>,
66 cx: &mut Context,
67 buf: &mut ReadBuf,
68 ) -> Poll<Result<(), io::Error>> {
69 let this = self.get_mut();
70 Pin::new(&mut this.inner).poll_read(cx, buf)
71 }
72}
73
74#[derive(Debug)]
77pub struct Child {
78 pub stdin: Option<ChildStdin>,
79 pub stdout: Option<ChildStdout>,
80 pub stderr: Option<ChildStderr>,
81 pub(crate) handle: JoinHandle<Result<ExitStatus, io::Error>>,
82}
83
84#[derive(Debug)]
85pub(crate) struct ChildImp<S>
86where
87 S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static,
88{
89 pub(crate) channel: Channel<S>,
90 pub(crate) stdin_rx: ReadHalf<SimplexStream>,
91 pub(crate) stdout_tx: WriteHalf<SimplexStream>,
92 pub(crate) stderr_tx: WriteHalf<SimplexStream>,
93}
94
95impl<S> ChildImp<S>
96where
97 S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static,
98{
99 pub async fn wait(mut self) -> Result<ExitStatus, io::Error> {
100 use tokio::io::AsyncWriteExt;
101
102 let mut code = ExitStatusImp::Processing;
103
104 let mut writer = self.channel.make_writer_ext(None);
105 let mut stdin_rx = self.stdin_rx;
106 tokio::spawn(async move {
107 let _ = tokio::io::copy(&mut stdin_rx, &mut writer).await; });
109
110 loop {
111 let Some(msg) = self.channel.wait().await else {
112 break;
113 };
114 match msg {
115 ChannelMsg::ExitStatus { exit_status } => {
116 code = ExitStatusImp::Code(exit_status);
118 }
119 ChannelMsg::Data { ref data } => {
120 self.stdout_tx.write_all(data).await?;
121 }
122 ChannelMsg::ExtendedData { ref data, ext: 1 } => {
123 self.stderr_tx.write_all(data).await?;
124 }
125 _ => {}
126 }
127 }
128 tokio::try_join!(self.stdout_tx.shutdown(), self.stderr_tx.shutdown())?;
129 Ok(ExitStatus { inner: code })
130 }
131}
132
133impl Child {
134 pub async fn wait(self) -> Result<ExitStatus, io::Error> {
136 self.handle.await?
137 }
138
139 pub async fn wait_with_output(mut self) -> Result<Output, io::Error> {
141 async fn read_to_end<A: AsyncRead + Unpin>(io: &mut Option<A>) -> io::Result<Vec<u8>> {
142 use tokio::io::AsyncReadExt;
143 let mut vec = Vec::new();
144 if let Some(io) = io.as_mut() {
145 io.read_to_end(&mut vec).await?;
146 }
147 Ok(vec)
148 }
149
150 let mut stdout_pipe = self.stdout.take();
151 let mut stderr_pipe = self.stderr.take();
152
153 let stdout_fut = read_to_end(&mut stdout_pipe);
154 let stderr_fut = read_to_end(&mut stderr_pipe);
155
156 let (status, stdout, stderr) = tokio::try_join!(self.wait(), stdout_fut, stderr_fut)?;
157
158 drop(stdout_pipe);
159 drop(stderr_pipe);
160
161 Ok(Output {
162 status,
163 stdout,
164 stderr,
165 })
166 }
167}
168
169#[derive(Debug)]
172pub struct Output {
173 pub status: ExitStatus,
174 pub stdout: Vec<u8>,
175 pub stderr: Vec<u8>,
176}