sunset_stdasync/
cmdline_client.rs

1use futures::pin_mut;
2#[allow(unused_imports)]
3use log::{debug, error, info, log, trace, warn};
4use sunset::event::CliEvent;
5
6use core::fmt::Debug;
7use core::str::FromStr;
8
9use sunset::{sshnames, AuthSigMsg, OwnedSig, Pty, SignKey};
10use sunset::{Error, Result, Runner, SessionCommand};
11use sunset_async::*;
12
13use embassy_sync::channel::{Channel, Receiver, Sender};
14use embassy_sync::signal::Signal;
15use embedded_io_async::{Read as _, Write as _};
16use std::collections::VecDeque;
17
18use tokio::io::AsyncReadExt;
19use tokio::io::AsyncWriteExt;
20use tokio::signal::unix::{signal, SignalKind};
21
22use futures::FutureExt;
23use futures::{future::Fuse, select_biased};
24
25use crate::pty::win_size;
26use crate::AgentClient;
27use crate::*;
28use crate::{raw_pty, RawPtyGuard};
29
/// A commandline client implementation
///
/// This opens a single channel and presents it to the stdin/stdout terminal.
///
/// Configure it with the builder-style methods (`port()`, `pty()`, `exec()`,
/// `subsystem()`, `add_authkey()`, `agent()`) and then drive it with `run()`.
pub struct CmdlineClient {
    /// Request sent once the session channel is open. Starts as a shell and
    /// is replaced by `exec()` or `subsystem()`.
    cmd: SessionCommand<String>,
    /// Set by `pty()` when the local terminal could be queried. Selects a PTY
    /// session in `open_session()` and enables window-change watching in `run()`.
    want_pty: bool,

    // parameters
    /// Keys offered for public-key authentication, front of the queue first.
    /// One key is popped for each `Pubkey` event.
    authkeys: VecDeque<SignKey>,
    /// Username given to the server and shown in the password prompt.
    username: String,
    /// Host name used for the known-hosts check.
    host: String,
    /// Port used for the known-hosts check, `sshnames::SSH_PORT` by default.
    port: u16,
    /// Agent asked to sign when an `AgentSign` event arrives.
    agent: Option<AgentClient>,

    /// Guard returned by `raw_pty()`, stored by `open_session()` for PTY
    /// sessions and handed to the channel IO task when the session opens.
    // NOTE(review): presumably restores the terminal mode on drop - confirm
    // against `RawPtyGuard`.
    pty_guard: Option<RawPtyGuard>,

    /// Terminal description from `pty::current_pty()`, consumed when the
    /// session is opened.
    pty: Option<Pty>,
}
48
49impl CmdlineClient {
50    pub fn new(username: impl AsRef<str>, host: impl AsRef<str>) -> Self {
51        Self {
52            cmd: SessionCommand::Shell,
53            want_pty: false,
54            agent: None,
55
56            username: username.as_ref().into(),
57            host: host.as_ref().into(),
58            port: sshnames::SSH_PORT,
59            authkeys: Default::default(),
60            pty: None,
61            pty_guard: None,
62        }
63    }
64
65    pub fn port(&mut self, port: u16) -> &mut Self {
66        self.port = port;
67        self
68    }
69
70    pub fn pty(&mut self) -> &mut Self {
71        match pty::current_pty() {
72            Ok(p) => {
73                self.pty = Some(p);
74                self.want_pty = true;
75            }
76            Err(e) => warn!("Failed getting current pty: {e:?}"),
77        };
78        self
79    }
80
81    pub fn exec(&mut self, cmd: impl AsRef<str>) -> &mut Self {
82        self.cmd = SessionCommand::Exec(cmd.as_ref().into());
83        self
84    }
85
86    pub fn subsystem(&mut self, subsystem: impl AsRef<str>) -> &mut Self {
87        self.cmd = SessionCommand::Subsystem(subsystem.as_ref().into());
88        self
89    }
90
91    pub fn add_authkey(&mut self, k: SignKey) {
92        self.authkeys.push_back(k)
93    }
94
95    pub fn agent(&mut self, agent: AgentClient) {
96        self.agent = Some(agent)
97    }
98
99    async fn chan_run(
100        io: ChanInOut<'_>,
101        io_err: Option<ChanIn<'_>>,
102        pty_guard: Option<RawPtyGuard>,
103    ) -> Result<()> {
104        // out
105        let fo = async {
106            let mut io = io.clone();
107            let mut so =
108                crate::stdout().map_err(|_| Error::msg("opening stdout failed"))?;
109            loop {
110                // TODO buffers
111                let mut buf = [0u8; 1000];
112                let l = io.read(&mut buf).await?;
113                if l == 0 {
114                    break;
115                }
116                so.write_all(&buf[..l]).await.map_err(|_| Error::ChannelEOF)?;
117            }
118            #[allow(unreachable_code)]
119            Ok::<_, sunset::Error>(())
120        };
121
122        // err
123        let fe = async {
124            // if io_err is None we complete immediately
125            if let Some(mut errin) = io_err {
126                let mut eo = crate::stderr_out()
127                    .map_err(|_e| Error::msg("opening stderr failed"))?;
128                loop {
129                    // TODO buffers
130                    let mut buf = [0u8; 1000];
131                    let l = errin.read(&mut buf).await?;
132                    if l == 0 {
133                        break;
134                    }
135                    eo.write_all(&buf[..l]).await.map_err(|_| Error::ChannelEOF)?;
136                }
137                #[allow(unreachable_code)]
138                Ok::<_, sunset::Error>(())
139            } else {
140                Ok(())
141            }
142        };
143
144        let terminate = Signal::<SunsetRawMutex, ()>::new();
145
146        // in
147        let fi = async {
148            let mut io = io.clone();
149            let mut si =
150                crate::stdin().map_err(|_| Error::msg("opening stdin failed"))?;
151            let mut esc =
152                if pty_guard.is_some() { Some(Escaper::new()) } else { None };
153
154            loop {
155                // TODO buffers
156                let mut buf = [0u8; 1000];
157                let l = si.read(&mut buf).await.map_err(|_| Error::ChannelEOF)?;
158                if l == 0 {
159                    return Err(Error::ChannelEOF);
160                }
161
162                let buf = &buf[..l];
163
164                if let Some(ref mut esc) = esc {
165                    let a = esc.escape(buf);
166                    match a {
167                        EscapeAction::None => (),
168                        EscapeAction::Output { extra } => {
169                            if let Some(e) = extra {
170                                io.write_all(&[e]).await?;
171                            }
172                            io.write_all(buf).await?;
173                        }
174                        EscapeAction::Terminate => {
175                            info!("Terminated");
176                            terminate.signal(());
177                            return Ok(());
178                        }
179                        EscapeAction::Suspend => {
180                            // disabled for the time being, doesn't resume OK.
181                            // perhaps a bad interaction with dup_async(),
182                            // maybe the new guard needs to be on the dup-ed
183                            // FDs?
184
185                            // pty_guard = None;
186                            // nix::sys::signal::raise(nix::sys::signal::Signal::SIGTSTP)
187                            // .unwrap_or_else(|e| {
188                            //     warn!("Failed to stop: {e:?}");
189                            // });
190                            // // suspended here until resumed externally
191                            // set_pty_guard(&mut pty_guard);
192                            // continue;
193                        }
194                    }
195                } else {
196                    io.write_all(buf).await?;
197                }
198            }
199            #[allow(unreachable_code)]
200            Ok::<_, sunset::Error>(())
201        };
202
203        // output needs to complete when the channel is closed
204        let fi = embassy_futures::select::select(fi, io.until_closed());
205
206        // let fo = fo.map(|x| {
207        //     error!("fo done {x:?}");
208        //     x
209        // });
210        // let fi = fi.map(|x| {
211        //     error!("fi done {x:?}");
212        //     x
213        // });
214        // let fe = fe.map(|x| {
215        //     error!("fe done {x:?}");
216        //     x
217        // });
218
219        let io_done = embassy_futures::join::join3(fe, fi, fo);
220        let _ = embassy_futures::select::select(io_done, terminate.wait()).await;
221        // TODO handle errors from the join?
222        Ok(())
223    }
224
225    /// Runs the `CmdlineClient` session to completion.
226    ///
227    /// Performs authentication, requests a shell or command, performs channel IO.
228    /// Will return `Ok` after the session ends normally, or an error.
229    pub async fn run<'g: 'a, 'a>(&mut self, cli: &'g SSHClient<'a>) -> Result<i32> {
230        let mut winch_signal = self
231            .want_pty
232            .then(|| signal(SignalKind::window_change()))
233            .transpose()
234            .unwrap_or_else(|_| {
235                warn!("Couldn't watch for window change signals");
236                None
237            });
238
239        let mut io = None;
240        let mut extin = None;
241
242        let launch_chan: Channel<
243            SunsetRawMutex,
244            (ChanInOut, Option<ChanIn>, Option<RawPtyGuard>),
245            1,
246        > = Channel::new();
247
248        let mut exit_code = 1i32;
249
250        let prog_loop = async {
251            loop {
252                let winch_fut = Fuse::terminated();
253                pin_mut!(winch_fut);
254                if let Some(w) = winch_signal.as_mut() {
255                    winch_fut.set(w.recv().fuse());
256                }
257
258                let mut ph = ProgressHolder::new();
259                let ev = cli.progress(&mut ph).await?;
260                // Note that while ph is held, calls to cli will block.
261                match ev {
262                    CliEvent::Hostkey(h) => {
263                        let key = h.hostkey()?;
264                        match knownhosts::check_known_hosts(
265                            &self.host, self.port, &key,
266                        ) {
267                            Ok(()) => h.accept(),
268                            Err(_e) => h.reject(),
269                        }?;
270                    }
271                    CliEvent::Username(u) => {
272                        u.username(&self.username)?;
273                    }
274                    CliEvent::Password(p) => {
275                        let pw = rpassword::prompt_password(format!(
276                            "password for {}: ",
277                            self.username
278                        ))?;
279                        p.password(pw)?;
280                    }
281                    CliEvent::Pubkey(p) => {
282                        if let Some(k) = self.authkeys.pop_front() {
283                            p.pubkey(k)
284                        } else {
285                            p.skip()
286                        }?;
287                    }
288                    CliEvent::AgentSign(k) => {
289                        let agent =
290                            self.agent.as_mut().expect("agent keys without agent?");
291                        let key = k.key()?;
292                        let msg = k.message()?;
293                        let sig = agent.sign_auth(key, &msg).await?;
294                        k.signed(&sig)?;
295                    }
296                    CliEvent::Authenticated => {
297                        debug!("Authentication succeeded");
298                        // drop it so we can use cli
299                        drop(ph);
300                        let (i, e) = self.open_session(cli).await?;
301                        io = Some(i);
302                        extin = e;
303                    }
304                    CliEvent::SessionOpened(mut opener) => {
305                        if let Some(p) = self.pty.take() {
306                            opener.pty(p)?;
307                        }
308                        opener.cmd(&self.cmd)?;
309                        // Start the IO loop
310                        // TODO is there a better way
311                        launch_chan
312                            .send((
313                                io.take().unwrap(),
314                                extin.take(),
315                                self.pty_guard.take(),
316                            ))
317                            .await;
318                    }
319                    CliEvent::SessionExit(ex) => {
320                        trace!("session exit {ex:?}");
321                        if let sunset::CliSessionExit::Status(u) = ex {
322                            if u <= 255 {
323                                exit_code =
324                                    i8::from_be_bytes([(u & 0xff) as u8]) as i32;
325                            } else {
326                                exit_code = 1;
327                            }
328                        }
329                    }
330                    CliEvent::Banner(b) => {
331                        println!("Banner from server:\n{}", b.banner()?)
332                    }
333                    CliEvent::Defunct => {
334                        trace!("break defunct");
335                        break Ok::<_, Error>(());
336                    }
337                    CliEvent::PollAgain => (),
338                }
339            }
340        };
341
342        let chanio = async {
343            let (io, extin, pty) = launch_chan.receive().await;
344            Self::chan_run(io, extin, pty).await
345        };
346
347        embassy_futures::select::select(prog_loop, chanio).await;
348
349        Ok(exit_code)
350    }
351
352    /// Requests a PTY or non-PTY session
353    ///
354    /// Sets up the PTY if required.
355    async fn open_session<'g: 'a, 'a>(
356        &mut self,
357        cli: &'g SSHClient<'a>,
358    ) -> Result<(ChanInOut<'g>, Option<ChanIn<'g>>)> {
359        trace!("opens s");
360        let (io, extin) = if self.want_pty {
361            set_pty_guard(&mut self.pty_guard);
362            let io = cli.open_session_pty().await?;
363            (io, None)
364        } else {
365            let (io, extin) = cli.open_session_nopty().await?;
366            (io, Some(extin))
367        };
368        Ok((io, extin))
369    }
370}
371
372fn set_pty_guard(pty_guard: &mut Option<RawPtyGuard>) {
373    match raw_pty() {
374        Ok(p) => *pty_guard = Some(p),
375        Err(e) => {
376            warn!("Failed getting raw pty: {e:?}");
377        }
378    }
379}
380
/// What the caller should do with the input just passed to
/// [`Escaper::escape`].
#[derive(Debug, PartialEq)]
enum EscapeAction {
    /// Forward nothing, the input was held back as a possible escape prefix.
    None,
    /// Forward the input, preceded by `extra` when a held-back byte has to
    /// be released first.
    Output { extra: Option<u8> },
    /// `~.` was typed, end the session.
    Terminate,
    /// `~` followed by ctrl-z was typed, suspend the client.
    Suspend,
}

/// Handles ~. escape sequences in an interactive shell.
///
/// The value is the state of the recogniser between keystrokes.
#[derive(Debug)]
enum Escaper {
    /// In the middle of a line, `~` has no special meaning.
    Idle,
    /// At the start of a line, `~` begins an escape sequence.
    Newline,
    /// A `~` at the start of a line has been held back.
    Escape,
}

impl Escaper {
    /// Creates a recogniser positioned at the start of a line.
    fn new() -> Self {
        // behave as though a '\r' had just been received
        Self::Newline
    }

    /// Handle ~. escape sequences.
    ///
    /// `buf` is one read from the terminal. Returns what to do with it and
    /// updates the state for the next call.
    fn escape(&mut self, buf: &[u8]) -> EscapeAction {
        // Only a lone byte counts as a keystroke. Longer input (such as a
        // paste) is never interpreted, which gives some protection against
        // pasted escape sequences.
        let key = match buf {
            [c] => Some(*c),
            _ => None,
        };

        let (next, action) = match (&*self, key) {
            // hold the '~' back until the following keystroke is known
            (Self::Newline, Some(b'~')) => (Self::Escape, EscapeAction::None),
            // "~~" sends a single '~', the one in buf
            (Self::Escape, Some(b'~')) => {
                (Self::Idle, EscapeAction::Output { extra: None })
            }
            (Self::Escape, Some(b'.')) => (Self::Idle, EscapeAction::Terminate),
            // ctrl-z, suspend
            (Self::Escape, Some(0x1a)) => (Self::Idle, EscapeAction::Suspend),
            // Anything else is ordinary input, which resets the recogniser.
            (state, key) => {
                // release the '~' that was previously held back
                let extra = matches!(state, Self::Escape).then_some(b'~');
                let next = if key == Some(b'\r') {
                    Self::Newline
                } else {
                    Self::Idle
                };
                (next, EscapeAction::Output { extra })
            }
        };

        *self = next;
        action
    }
}
459
#[cfg(test)]
pub(crate) mod tests {
    use crate::cmdline_client::*;

    /// Feeds each input to a fresh `Escaper` one byte at a time and checks
    /// both the bytes that get forwarded and the action for the last byte.
    #[test]
    fn escaping() {
        // (input, expected last action, expected forwarded bytes)
        // A `None` expected action is shorthand for any `Output`.
        let cases: [(&str, Option<EscapeAction>, &str); 9] = [
            ("~.", Some(EscapeAction::Terminate), ""),
            ("\r~.", Some(EscapeAction::Terminate), "\r"),
            ("~~.", None, "~."),
            ("~~~.", None, "~~."),
            ("\r\r~.", Some(EscapeAction::Terminate), "\r\r"),
            ("a~/~.", None, "a~/~."),
            ("a~/\r~.", Some(EscapeAction::Terminate), "a~/\r"),
            ("~\r~.", Some(EscapeAction::Terminate), "~\r"),
            ("~\r~ ", None, "~\r~ "),
        ];

        for (input, want_action, want_out) in &cases {
            println!("input \"{}\"", input.escape_default());

            let mut esc = Escaper::new();
            let mut forwarded: Vec<u8> = Vec::new();
            let mut final_action = None;

            // every input in the table is ASCII, so bytes are keystrokes
            for key in input.bytes() {
                let action = esc.escape(&[key]);
                if let EscapeAction::Output { extra } = &action {
                    forwarded.extend(*extra);
                    forwarded.push(key);
                }
                final_action = Some(action);
            }
            assert_eq!(forwarded.as_slice(), want_out.as_bytes());

            let final_action = final_action.unwrap();
            match (want_action, &final_action) {
                (Some(want), got) => assert_eq!(got, want),
                (None, EscapeAction::Output { .. }) => (),
                (None, _) => panic!("Unexpected action {final_action:?}"),
            }
        }
    }
}