ssh_muxcontrol/
session.rs

1use std::convert::TryInto;
2use std::error::Error;
3use std::fmt;
4use std::os::unix::io::AsRawFd;
5
6use bytes::{BufMut, BytesMut};
7use sendfd::SendWithFd;
8use tokio::{
9    io::{AsyncReadExt, AsyncWriteExt},
10    net::UnixStream,
11};
12use tokio_pipe::{PipeRead, PipeWrite};
13
14use crate::commands::{
15    MuxCmd, MuxCmdCheckAlive, MuxCmdHello, MuxCmdNewSession, MuxRespCheckAlive,
16    MuxRespExit, MuxRespHello, MuxRespNewSession,
17};
18use crate::SshctlError;
19
/// Outcome of a remote command that has run to completion: everything it
/// wrote to stdout and stderr (as raw, undecoded bytes) plus its exit code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShellResult {
    /// Bytes the remote command wrote to its standard output.
    pub stdout: Vec<u8>,
    /// Bytes the remote command wrote to its standard error.
    pub stderr: Vec<u8>,
    /// Exit status reported for the remote command.
    pub exit_code: u32,
}
28
/// Error produced when talking to an SSH control (mux) socket fails or the
/// master answers with something unexpected. It only carries a
/// human-readable description of what went wrong.
#[derive(Debug)]
pub struct MuxError {
    details: String,
}

impl MuxError {
    /// Wraps an already-formatted description of the failure.
    fn new(details: String) -> Self {
        MuxError { details }
    }
}

impl fmt::Display for MuxError {
    /// Displays the stored description verbatim.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.details)
    }
}

impl Error for MuxError {}
48
49/// Runs a given shell command on the remote hosts default shell
50/// though an existing SSH UNIX control socket.
51/// The SSH control socket is created
52/// outside of this crate by an existing SSH connection.
53pub async fn run(
54    path: &str,
55    command: &str,
56) -> Result<ShellResult, SshctlError> {
57    run_stdin(path, command, None).await
58}
59
60/// Runs a given shell command on the remote hosts default shell
61/// though an existing SSH UNIX control socket.
62/// The SSH control socket is created
63/// outside of this crate by an existing SSH connection.
64///
65/// This function is the same as `run` but custom data is supplied to
66/// the remote commands STDIN.
67pub async fn run_stdin(
68    ctlpath: &str,
69    command: &str,
70    stdin: Option<Vec<u8>>,
71) -> Result<ShellResult, SshctlError> {
72    //#[cfg(debug_assertions)]
73    //eprintln!("Mux::request_session: {}", ctlpath);
74    let mut socket = UnixStream::connect(ctlpath).await?;
75
76    hello(&mut socket).await?;
77    let request_id = check_mux_alive(&mut socket, 0).await?;
78    let (session_id, local_stdin, local_stdout, local_stderr) =
79        new_session(&mut socket, request_id, command).await?;
80
81    let stdin_data = stdin.unwrap_or_default();
82
83    let (rx_rc, tx_stdin, rx_stdout, rx_stderr) = tokio::join! {
84        wait(&mut socket, session_id),
85        write_stdin(local_stdin, &stdin_data[..]),
86        read_ssh_pipe(local_stdout),
87        read_ssh_pipe(local_stderr),
88    };
89
90    tx_stdin?;
91
92    Ok(ShellResult {
93        stdout: rx_stdout?,
94        stderr: rx_stderr?,
95        exit_code: rx_rc?,
96    })
97}
98
99async fn read_packet_response(
100    socket: &mut UnixStream,
101) -> Result<Vec<u8>, std::io::Error> {
102    let mut buffer = [0; 4];
103
104    socket.read_exact(&mut buffer).await?;
105    let length = u32::from_be_bytes(buffer) as usize;
106    //eprintln!("Received length: {}", &length);
107    let mut response = vec![0; length];
108    socket.read_exact(&mut response[0..length]).await?;
109    Ok(response)
110}
111
112async fn write_command<T: MuxCmd>(
113    socket: &mut UnixStream,
114    command: &T,
115) -> Result<(), std::io::Error> {
116    let mut buffer = BytesMut::with_capacity(command.length() + 4);
117    buffer.put_u32(command.length().try_into().unwrap());
118    command.serialize(&mut buffer);
119
120    //#[cfg(debug_assertions)]
121    //eprintln!("writing mux command: {:?}", &buffer);
122
123    socket.write(&buffer).await.map(|_| ())
124}
125
126async fn hello(socket: &mut UnixStream) -> Result<(), SshctlError> {
127    //#[cfg(debug_assertions)]
128    //eprintln!("Mux::hello");
129
130    let response = match read_packet_response(socket).await {
131        Ok(x) => MuxRespHello::deserialize(&mut x.as_slice())?,
132        Err(e) => {
133            return Err(MuxError::new(format!(
134                "Read MuxRespHello failed: {:?}",
135                e
136            ))
137            .into())
138        }
139    };
140
141    if !response.is_valid() {
142        return Err(MuxError::new(format!(
143            "Received invalid hello message: {:?}",
144            response
145        ))
146        .into());
147    }
148
149    let command = MuxCmdHello {};
150
151    if let Err(e) = write_command(socket, &command).await {
152        return Err(MuxError::new(format!(
153            "Write MuxCmdHello failed: {:?}",
154            e
155        ))
156        .into());
157    }
158    Ok(())
159}
160
161async fn check_mux_alive(
162    socket: &mut UnixStream,
163    request_id: u32,
164) -> Result<u32, SshctlError> {
165    //#[cfg(debug_assertions)]
166    //eprintln!("Mux::check_alive");
167
168    let command = MuxCmdCheckAlive::new(request_id);
169
170    if let Err(e) = write_command(socket, &command).await {
171        return Err(MuxError::new(format!(
172            "Write check alive request failed: {:?}",
173            e
174        ))
175        .into());
176    }
177
178    let response = match read_packet_response(socket).await {
179        Ok(x) => MuxRespCheckAlive::deserialize(&mut x.as_slice())?,
180        Err(e) => {
181            return Err(MuxError::new(format!(
182                "Read MuxRespCheckAlive failed: {:?}",
183                e
184            ))
185            .into())
186        }
187    };
188
189    if !response.is_valid(request_id) {
190        return Err(MuxError::new(format!(
191            "Received invalid check_alive message: {:?}",
192            response
193        ))
194        .into());
195    }
196
197    Ok(request_id + 1)
198}
199
200async fn new_session(
201    socket: &mut UnixStream,
202    request_id: u32,
203    command: &str,
204) -> Result<(u32, PipeWrite, PipeRead, PipeRead), SshctlError> {
205    //#[cfg(debug_assertions)]
206    //eprintln!("Mux::new_session");
207
208    let command = MuxCmdNewSession::new(request_id, command.into());
209
210    if let Err(e) = write_command(socket, &command).await {
211        return Err(MuxError::new(format!(
212            "Write new session request failed: {:?}",
213            e
214        ))
215        .into());
216    }
217
218    let (remote_stdin, local_stdin) = tokio_pipe::pipe()?;
219    let (local_stdout, remote_stdout) = tokio_pipe::pipe()?;
220    let (local_stderr, remote_stderr) = tokio_pipe::pipe()?;
221
222    let fds: [i32; 1] = [remote_stdin.as_raw_fd()];
223    if let Err(e) = socket.send_with_fd(b" ", &fds) {
224        return Err(
225            MuxError::new(format!("send_with_fd failed: {:?}", e)).into()
226        );
227    }
228
229    let fds: [i32; 1] = [remote_stdout.as_raw_fd()];
230    if let Err(e) = socket.send_with_fd(b" ", &fds) {
231        return Err(
232            MuxError::new(format!("send_with_fd failed: {:?}", e)).into()
233        );
234    }
235
236    let fds: [i32; 1] = [remote_stderr.as_raw_fd()];
237    if let Err(e) = socket.send_with_fd(b" ", &fds) {
238        return Err(
239            MuxError::new(format!("send_with_fd failed: {:?}", e)).into()
240        );
241    }
242
243    let response = match read_packet_response(socket).await {
244        Ok(x) => MuxRespNewSession::deserialize(&mut x.as_slice())?,
245        Err(e) => {
246            return Err(MuxError::new(format!(
247                "Read MuxRespNewSession failed: {:?}",
248                e
249            ))
250            .into())
251        }
252    };
253
254    if !response.is_valid(request_id) {
255        return Err(MuxError::new(format!(
256            "Received invalid new_session message: {:?}",
257            response
258        ))
259        .into());
260    }
261
262    Ok((
263        response.session_id(),
264        local_stdin,
265        local_stdout,
266        local_stderr,
267    ))
268}
269
270async fn wait(
271    socket: &mut UnixStream,
272    session_id: u32,
273) -> Result<u32, SshctlError> {
274    let response = match read_packet_response(socket).await {
275        Ok(x) => MuxRespExit::deserialize(&mut x.as_slice())?,
276        Err(e) => {
277            return Err(MuxError::new(format!(
278                "Read MuxRespExit failed: {:?}",
279                e
280            ))
281            .into())
282        }
283    };
284
285    if !response.is_valid(session_id) {
286        return Err(MuxError::new(format!(
287            "Received invalid exit message: {:?}",
288            response
289        ))
290        .into());
291    }
292
293    Ok(response.exit_code())
294}
295
296async fn write_stdin(
297    mut local_stdin: PipeWrite,
298    buffer: &[u8],
299) -> Result<(), MuxError> {
300    if let Err(e) = local_stdin.write_all(buffer).await {
301        return Err(MuxError::new(format!("Write stdin failed: {:?}", e)));
302    }
303    Ok(())
304}
305
306async fn read_ssh_pipe(mut pipe: PipeRead) -> Result<Vec<u8>, MuxError> {
307    let mut data = Vec::<u8>::with_capacity(1024);
308    let mut buffer = [0; 1024];
309    loop {
310        match pipe.read(&mut buffer).await {
311            Ok(count) => {
312                //#[cfg(debug_assertions)]
313                //eprintln!("received from pipe: {}: {:?}", count, data);
314                if count == 0 {
315                    return Ok(data);
316                }
317                data.append(&mut buffer[..count].to_vec());
318            }
319            Err(e) => {
320                return Err(MuxError::new(format!(
321                    "Read stdout failed: {:?}",
322                    e
323                )))
324            }
325        };
326    }
327}