1use std::convert::TryInto;
2use std::error::Error;
3use std::fmt;
4use std::os::unix::io::AsRawFd;
5
6use bytes::{BufMut, BytesMut};
7use sendfd::SendWithFd;
8use tokio::{
9 io::{AsyncReadExt, AsyncWriteExt},
10 net::UnixStream,
11};
12use tokio_pipe::{PipeRead, PipeWrite};
13
14use crate::commands::{
15 MuxCmd, MuxCmdCheckAlive, MuxCmdHello, MuxCmdNewSession, MuxRespCheckAlive,
16 MuxRespExit, MuxRespHello, MuxRespNewSession,
17};
18use crate::SshctlError;
19
/// Captured outcome of a remote command: the raw bytes it wrote to stdout
/// and stderr, together with its exit code.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ShellResult {
    /// Everything the command wrote to standard output.
    pub stdout: Vec<u8>,
    /// Everything the command wrote to standard error.
    pub stderr: Vec<u8>,
    /// Exit code reported for the command.
    pub exit_code: u32,
}
28
/// Error describing a failed step of the mux exchange, for example an I/O
/// failure while talking to the master or a reply that did not validate.
#[derive(Debug)]
pub struct MuxError {
    details: String,
}

impl MuxError {
    /// Builds an error whose message is `details`.
    fn new(details: String) -> Self {
        MuxError { details }
    }
}

impl fmt::Display for MuxError {
    /// Emits the stored message verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.details)
    }
}

impl Error for MuxError {}
48
49pub async fn run(
54 path: &str,
55 command: &str,
56) -> Result<ShellResult, SshctlError> {
57 run_stdin(path, command, None).await
58}
59
60pub async fn run_stdin(
68 ctlpath: &str,
69 command: &str,
70 stdin: Option<Vec<u8>>,
71) -> Result<ShellResult, SshctlError> {
72 let mut socket = UnixStream::connect(ctlpath).await?;
75
76 hello(&mut socket).await?;
77 let request_id = check_mux_alive(&mut socket, 0).await?;
78 let (session_id, local_stdin, local_stdout, local_stderr) =
79 new_session(&mut socket, request_id, command).await?;
80
81 let stdin_data = stdin.unwrap_or_default();
82
83 let (rx_rc, tx_stdin, rx_stdout, rx_stderr) = tokio::join! {
84 wait(&mut socket, session_id),
85 write_stdin(local_stdin, &stdin_data[..]),
86 read_ssh_pipe(local_stdout),
87 read_ssh_pipe(local_stderr),
88 };
89
90 tx_stdin?;
91
92 Ok(ShellResult {
93 stdout: rx_stdout?,
94 stderr: rx_stderr?,
95 exit_code: rx_rc?,
96 })
97}
98
99async fn read_packet_response(
100 socket: &mut UnixStream,
101) -> Result<Vec<u8>, std::io::Error> {
102 let mut buffer = [0; 4];
103
104 socket.read_exact(&mut buffer).await?;
105 let length = u32::from_be_bytes(buffer) as usize;
106 let mut response = vec![0; length];
108 socket.read_exact(&mut response[0..length]).await?;
109 Ok(response)
110}
111
112async fn write_command<T: MuxCmd>(
113 socket: &mut UnixStream,
114 command: &T,
115) -> Result<(), std::io::Error> {
116 let mut buffer = BytesMut::with_capacity(command.length() + 4);
117 buffer.put_u32(command.length().try_into().unwrap());
118 command.serialize(&mut buffer);
119
120 socket.write(&buffer).await.map(|_| ())
124}
125
126async fn hello(socket: &mut UnixStream) -> Result<(), SshctlError> {
127 let response = match read_packet_response(socket).await {
131 Ok(x) => MuxRespHello::deserialize(&mut x.as_slice())?,
132 Err(e) => {
133 return Err(MuxError::new(format!(
134 "Read MuxRespHello failed: {:?}",
135 e
136 ))
137 .into())
138 }
139 };
140
141 if !response.is_valid() {
142 return Err(MuxError::new(format!(
143 "Received invalid hello message: {:?}",
144 response
145 ))
146 .into());
147 }
148
149 let command = MuxCmdHello {};
150
151 if let Err(e) = write_command(socket, &command).await {
152 return Err(MuxError::new(format!(
153 "Write MuxCmdHello failed: {:?}",
154 e
155 ))
156 .into());
157 }
158 Ok(())
159}
160
161async fn check_mux_alive(
162 socket: &mut UnixStream,
163 request_id: u32,
164) -> Result<u32, SshctlError> {
165 let command = MuxCmdCheckAlive::new(request_id);
169
170 if let Err(e) = write_command(socket, &command).await {
171 return Err(MuxError::new(format!(
172 "Write check alive request failed: {:?}",
173 e
174 ))
175 .into());
176 }
177
178 let response = match read_packet_response(socket).await {
179 Ok(x) => MuxRespCheckAlive::deserialize(&mut x.as_slice())?,
180 Err(e) => {
181 return Err(MuxError::new(format!(
182 "Read MuxRespCheckAlive failed: {:?}",
183 e
184 ))
185 .into())
186 }
187 };
188
189 if !response.is_valid(request_id) {
190 return Err(MuxError::new(format!(
191 "Received invalid check_alive message: {:?}",
192 response
193 ))
194 .into());
195 }
196
197 Ok(request_id + 1)
198}
199
200async fn new_session(
201 socket: &mut UnixStream,
202 request_id: u32,
203 command: &str,
204) -> Result<(u32, PipeWrite, PipeRead, PipeRead), SshctlError> {
205 let command = MuxCmdNewSession::new(request_id, command.into());
209
210 if let Err(e) = write_command(socket, &command).await {
211 return Err(MuxError::new(format!(
212 "Write new session request failed: {:?}",
213 e
214 ))
215 .into());
216 }
217
218 let (remote_stdin, local_stdin) = tokio_pipe::pipe()?;
219 let (local_stdout, remote_stdout) = tokio_pipe::pipe()?;
220 let (local_stderr, remote_stderr) = tokio_pipe::pipe()?;
221
222 let fds: [i32; 1] = [remote_stdin.as_raw_fd()];
223 if let Err(e) = socket.send_with_fd(b" ", &fds) {
224 return Err(
225 MuxError::new(format!("send_with_fd failed: {:?}", e)).into()
226 );
227 }
228
229 let fds: [i32; 1] = [remote_stdout.as_raw_fd()];
230 if let Err(e) = socket.send_with_fd(b" ", &fds) {
231 return Err(
232 MuxError::new(format!("send_with_fd failed: {:?}", e)).into()
233 );
234 }
235
236 let fds: [i32; 1] = [remote_stderr.as_raw_fd()];
237 if let Err(e) = socket.send_with_fd(b" ", &fds) {
238 return Err(
239 MuxError::new(format!("send_with_fd failed: {:?}", e)).into()
240 );
241 }
242
243 let response = match read_packet_response(socket).await {
244 Ok(x) => MuxRespNewSession::deserialize(&mut x.as_slice())?,
245 Err(e) => {
246 return Err(MuxError::new(format!(
247 "Read MuxRespNewSession failed: {:?}",
248 e
249 ))
250 .into())
251 }
252 };
253
254 if !response.is_valid(request_id) {
255 return Err(MuxError::new(format!(
256 "Received invalid new_session message: {:?}",
257 response
258 ))
259 .into());
260 }
261
262 Ok((
263 response.session_id(),
264 local_stdin,
265 local_stdout,
266 local_stderr,
267 ))
268}
269
270async fn wait(
271 socket: &mut UnixStream,
272 session_id: u32,
273) -> Result<u32, SshctlError> {
274 let response = match read_packet_response(socket).await {
275 Ok(x) => MuxRespExit::deserialize(&mut x.as_slice())?,
276 Err(e) => {
277 return Err(MuxError::new(format!(
278 "Read MuxRespExit failed: {:?}",
279 e
280 ))
281 .into())
282 }
283 };
284
285 if !response.is_valid(session_id) {
286 return Err(MuxError::new(format!(
287 "Received invalid exit message: {:?}",
288 response
289 ))
290 .into());
291 }
292
293 Ok(response.exit_code())
294}
295
296async fn write_stdin(
297 mut local_stdin: PipeWrite,
298 buffer: &[u8],
299) -> Result<(), MuxError> {
300 if let Err(e) = local_stdin.write_all(buffer).await {
301 return Err(MuxError::new(format!("Write stdin failed: {:?}", e)));
302 }
303 Ok(())
304}
305
306async fn read_ssh_pipe(mut pipe: PipeRead) -> Result<Vec<u8>, MuxError> {
307 let mut data = Vec::<u8>::with_capacity(1024);
308 let mut buffer = [0; 1024];
309 loop {
310 match pipe.read(&mut buffer).await {
311 Ok(count) => {
312 if count == 0 {
315 return Ok(data);
316 }
317 data.append(&mut buffer[..count].to_vec());
318 }
319 Err(e) => {
320 return Err(MuxError::new(format!(
321 "Read stdout failed: {:?}",
322 e
323 )))
324 }
325 };
326 }
327}