sunset_async/
cmdline_client.rs

1use futures::pin_mut;
2#[allow(unused_imports)]
3use log::{debug, error, info, log, trace, warn};
4
5use core::str::FromStr;
6use core::fmt::Debug;
7
8use sunset::{AuthSigMsg, SignKey, OwnedSig, Pty, sshnames};
9use sunset::{BhError, BhResult};
10use sunset::{Error, Result, Runner, SessionCommand};
11use sunset_embassy::*;
12
13use std::collections::VecDeque;
14use embassy_sync::channel::{Channel, Sender, Receiver};
15use embassy_sync::signal::Signal;
16use embedded_io_async::{Read as _, Write as _};
17
18use tokio::io::AsyncReadExt;
19use tokio::io::AsyncWriteExt;
20use tokio::signal::unix::{signal, SignalKind};
21
22use futures::{select_biased, future::Fuse};
23use futures::FutureExt;
24
25use crate::*;
26use crate::AgentClient;
27use crate::{raw_pty, RawPtyGuard};
28use crate::pty::win_size;
29
/// Progress of the single session channel driven by `CmdlineRunner`.
#[derive(Debug)]
enum CmdlineState<'a> {
    /// Initial state, until `Msg::Authed` is received.
    PreAuth,
    /// Authentication succeeded and no channel handle is held.
    ///
    /// `CmdlineRunner::run` also uses this as the placeholder value while it
    /// temporarily moves the state out with `mem::replace`.
    Authed,
    /// A session channel has been requested by `open_session`; waiting for
    /// `Msg::Opened` before IO starts.
    Opening {
        /// Channel handle used for both reading and writing.
        io: ChanInOut<'a>,
        /// Second input stream that `chan_run` copies to stderr.
        /// `None` when the session was opened with a pty.
        extin: Option<ChanIn<'a>>,
    },
    /// Channel IO has been started by `run`.
    Ready {
        /// Channel handle, kept so window size changes can still be sent.
        io: ChanInOut<'a>,
    },
}
42
/// Notifications passed from `CmdlineHooks` to `CmdlineRunner` over the
/// `notify` channel.
enum Msg {
    /// Authentication succeeded (sent by `CmdlineHooks::authenticated`).
    Authed,
    /// The session channel was opened (sent by `CmdlineHooks::session_opened`).
    Opened,
    /// The SSH session exited
    Exited,
}
49
/// A commandline client session
///
/// This opens a single channel and presents it to the stdin/stdout terminal.
///
/// Configure it with the builder-style methods, then call `split()` to obtain
/// the `CmdlineHooks` and `CmdlineRunner` halves that borrow from it.
pub struct CmdlineClient {
    /// What to request on the session channel: a shell (the default), an
    /// exec command or a subsystem.
    cmd: SessionCommand<String>,
    /// Whether a pty should be requested for the session.
    want_pty: bool,

    // to be passed to hooks
    /// Keys to offer for authentication, in order. Moved out by `split()`.
    authkeys: VecDeque<SignKey>,
    /// Login name on the remote host.
    username: String,
    /// Host name, used for the known-hosts check.
    host: String,
    /// Port, used for the known-hosts check.
    port: u16,
    /// Optional agent used for signing. Moved out by `split()`.
    agent: Option<AgentClient>,

    /// Channel carrying `Msg` notifications from the hooks to the runner.
    notify: Channel<SunsetRawMutex, Msg, 1>,
}
66
/// The half of a split `CmdlineClient` that drives the session.
///
/// `run()` reacts to notifications from `CmdlineHooks`, opens the session
/// channel and performs the channel IO.
pub struct CmdlineRunner<'a> {
    /// Current stage of the session.
    state: CmdlineState<'a>,

    /// Whether the session is opened with a pty.
    want_pty: bool,
    /// Guard obtained from `raw_pty()` when a pty session is opened; handed
    /// over to `chan_run` once the channel is open.
    pty_guard: Option<RawPtyGuard>,

    /// Receiving end of the notification channel.
    notify: Receiver<'a, SunsetRawMutex, Msg, 1>,
}
75
/// The half of a split `CmdlineClient` that answers `sunset::CliBehaviour`
/// callbacks and notifies the `CmdlineRunner` of progress.
pub struct CmdlineHooks<'a> {
    /// Remaining keys to offer; `next_authkey` pops from the front.
    authkeys: VecDeque<SignKey>,
    /// Login name, borrowed from the `CmdlineClient`.
    username: &'a str,
    /// Host name for the known-hosts check.
    host: &'a str,
    /// Port for the known-hosts check.
    port: u16,
    /// Agent used by `agent_sign`, if any.
    agent: Option<AgentClient>,
    /// Command requested when the session channel opens.
    cmd: &'a SessionCommand<String>,
    /// Pty parameters to request; taken (left as `None`) by `session_opened`.
    pty: Option<Pty>,

    /// Sending end of the notification channel.
    notify: Sender<'a, SunsetRawMutex, Msg, 1>,
}
87
88impl CmdlineClient {
89    pub fn new(username: impl AsRef<str>, host: impl AsRef<str>) -> Self {
90        Self {
91            cmd: SessionCommand::Shell,
92            want_pty: false,
93            agent: None,
94
95            notify: Channel::new(),
96
97            username: username.as_ref().into(),
98            host: host.as_ref().into(),
99            port: sshnames::SSH_PORT,
100            authkeys: Default::default(),
101        }
102    }
103
104    /// Splits a `CmdlineClient` into hooks and the runner.
105    ///
106    /// `CmdlineRunner` should be awaited until the session completes.
107    /// `CmdlineHooks` can be used to exit early (and may in future provide
108    /// other functionality).
109    pub fn split(&mut self) -> (CmdlineHooks, CmdlineRunner) {
110
111        let pty = self.make_pty();
112
113        let authkeys = core::mem::replace(&mut self.authkeys, Default::default());
114
115        let runner = CmdlineRunner::new(pty.is_some(), self.notify.receiver());
116
117        let hooks = CmdlineHooks {
118            username: &self.username,
119            host: &self.host,
120            port: self.port,
121            authkeys,
122            agent: self.agent.take(),
123            cmd: &self.cmd,
124            pty,
125            notify: self.notify.sender(),
126        };
127
128        (hooks, runner)
129    }
130
131    pub fn port(&mut self, port: u16) -> &mut Self {
132        self.port = port;
133        self
134    }
135
136    pub fn pty(&mut self) -> &mut Self {
137        self.want_pty = true;
138        self
139    }
140
141    pub fn exec(&mut self, cmd: &str) -> &mut Self {
142        self.cmd = SessionCommand::Exec(cmd.into());
143        self
144    }
145
146    pub fn subsystem(&mut self, subsystem: &str) -> &mut Self {
147        self.cmd = SessionCommand::Subsystem(subsystem.into());
148        self
149    }
150
151    pub fn add_authkey(&mut self, k: SignKey) {
152        self.authkeys.push_back(k)
153    }
154
155    pub fn agent(&mut self, agent: AgentClient) {
156        self.agent = Some(agent)
157    }
158
159    fn make_pty(&mut self) -> Option<Pty> {
160        let mut pty = None;
161        if self.want_pty {
162            match pty::current_pty() {
163                Ok(p) => pty = Some(p),
164                Err(e) => warn!("Failed getting current pty: {e:?}"),
165            }
166
167        }
168        pty
169    }
170
171}
172
173
174impl<'a> CmdlineRunner<'a> {
175    fn new(want_pty: bool, notify: Receiver<'a, SunsetRawMutex, Msg, 1>) -> Self {
176        Self {
177            state: CmdlineState::PreAuth,
178            want_pty,
179            notify,
180            pty_guard: None,
181        }
182    }
183
    /// Copies data between the channel and the local stdio streams.
    ///
    /// Three futures run concurrently:
    /// * `fo`: channel -> stdout, until the channel read returns 0 bytes.
    /// * `fe`: `io_err` -> stderr, until the read returns 0 bytes. Completes
    ///   immediately when `io_err` is `None`.
    /// * `fi`: stdin -> channel. When `pty_guard` is `Some`, input is passed
    ///   through an `Escaper` so that escape sequences can end the session.
    ///
    /// Returns once all three have completed, or as soon as the terminate
    /// escape is entered. Errors from the individual futures are currently
    /// discarded: the return value is always `Ok(())`.
    ///
    /// `pty_guard` is owned by this function for its whole duration.
    // NOTE(review): presumably dropping the guard restores the terminal mode;
    // confirm against `RawPtyGuard`.
    async fn chan_run(io: ChanInOut<'a>,
        io_err: Option<ChanIn<'a>>,
        pty_guard: Option<RawPtyGuard>) -> Result<()> {
        // out: channel -> stdout
        let fo = async {
            let mut io = io.clone();
            let mut so = crate::stdout().map_err(|_| {
                Error::msg("opening stdout failed")
            })?;
            loop {
                // TODO buffers
                let mut buf = [0u8; 1000];
                let l = io.read(&mut buf).await?;
                // a zero-length read ends the copy loop
                if l == 0 {
                    break;
                }
                // any stdout write failure is reported as ChannelEOF
                so.write_all(&buf[..l]).await.map_err(|_| Error::ChannelEOF)?;
            }
            #[allow(unreachable_code)]
            Ok::<_, sunset::Error>(())
        };

        // err: extended input -> stderr
        let fe = async {
            // if io_err is None we complete immediately
            if let Some(mut errin) = io_err {
                let mut eo = crate::stderr_out().map_err(|_e| {
                    Error::msg("opening stderr failed")
                })?;
                loop {
                    // TODO buffers
                    let mut buf = [0u8; 1000];
                    let l = errin.read(&mut buf).await?;
                    // a zero-length read ends the copy loop
                    if l == 0 {
                        break;
                    }
                    // any stderr write failure is reported as ChannelEOF
                    eo.write_all(&buf[..l]).await.map_err(|_| Error::ChannelEOF)?;
                }
                #[allow(unreachable_code)]
                Ok::<_, sunset::Error>(())
            } else {
                Ok(())
            }
        };

        // Signalled by the input future when the terminate escape is entered.
        let terminate = Signal::<SunsetRawMutex, ()>::new();

        // in: stdin -> channel
        let fi = async {
            let mut io = io.clone();
            let mut si = crate::stdin().map_err(|_| Error::msg("opening stdin failed"))?;
            // Escape sequences are only interpreted for pty sessions.
            let mut esc = if pty_guard.is_some() {
                Some(Escaper::new())
            } else {
                None
            };

            loop {
                // TODO buffers
                let mut buf = [0u8; 1000];
                let l = si.read(&mut buf).await.map_err(|_| Error::ChannelEOF)?;
                // a zero-length stdin read ends this future with an error;
                // the output futures keep running
                if l == 0 {
                    return Err(Error::ChannelEOF)
                }

                let buf = &buf[..l];

                if let Some(ref mut esc) = esc {
                    let a = esc.escape(buf);
                    match a {
                        // input swallowed by the escaper, nothing to send yet
                        EscapeAction::None => (),
                        EscapeAction::Output { extra } => {
                            // `extra` is a previously withheld byte that has
                            // to be sent ahead of the current input
                            if let Some(e) = extra {
                                io.write_all(&[e]).await?;
                            }
                            io.write_all(buf).await?;
                        }
                        EscapeAction::Terminate => {
                            info!("Terminated");
                            terminate.signal(());
                            return Ok(())
                        }
                        EscapeAction::Suspend => {
                            // disabled for the time being, doesn't resume OK.
                            // perhaps a bad interaction with dup_async(),
                            // maybe the new guard needs to be on the dup-ed
                            // FDs?
                            ()

                            // pty_guard = None;
                            // nix::sys::signal::raise(nix::sys::signal::Signal::SIGTSTP)
                            // .unwrap_or_else(|e| {
                            //     warn!("Failed to stop: {e:?}");
                            // });
                            // // suspended here until resumed externally
                            // set_pty_guard(&mut pty_guard);
                            // continue;
                        }
                    }
                } else {
                    io.write_all(buf).await?;
                }

            }
            #[allow(unreachable_code)]
            Ok::<_, sunset::Error>(())
        };

        // the input future needs to complete when the channel is closed
        let fi = embassy_futures::select::select(fi, io.until_closed());

        // let fo = fo.map(|x| {
        //     error!("fo done {x:?}");
        //     x
        // });
        // let fi = fi.map(|x| {
        //     error!("fi done {x:?}");
        //     x
        // });
        // let fe = fe.map(|x| {
        //     error!("fe done {x:?}");
        //     x
        // });

        // Finish when all IO is done, or immediately on the terminate escape.
        let io_done = embassy_futures::join::join3(fe, fi, fo);
        let _ = embassy_futures::select::select(io_done, terminate.wait()).await;
        // TODO handle errors from the join?
        Ok(())
    }
313
314    /// Runs the `CmdlineClient` session to completion.
315    ///
316    /// Performs authentication, requests a shell or command, performs channel IO.
317    /// Will return `Ok` after the session ends normally, or an error.
318    pub async fn run(&mut self, cli: &'a SSHClient<'a>) -> Result<()> {
319        // chanio is only set once a channel is opened below
320        let chanio = Fuse::terminated();
321        pin_mut!(chanio);
322
323        let mut winch_signal = self.want_pty
324            .then(|| signal(SignalKind::window_change()))
325            .transpose()
326            .unwrap_or_else(|_| {
327                warn!("Couldn't watch for window change signals");
328                None
329            });
330
331        loop {
332            let winch_fut = Fuse::terminated();
333            pin_mut!(winch_fut);
334            if let Some(w) = winch_signal.as_mut() {
335                winch_fut.set(w.recv().fuse());
336            }
337
338            select_biased! {
339                msg = self.notify.receive().fuse() => {
340                    match msg {
341                        Msg::Authed => {
342                            if !matches!(self.state, CmdlineState::PreAuth) {
343                                warn!("Unexpected auth success, state {:?}", self.state);
344                                return Ok(())
345                            }
346                            self.state = CmdlineState::Authed;
347                            debug!("Opening a new session channel");
348                            self.open_session(cli).await?;
349                        }
350                        Msg::Opened => {
351                            let st = core::mem::replace(&mut self.state, CmdlineState::Authed);
352                            if let CmdlineState::Opening { io, extin } = st {
353                                let r = Self::chan_run(io.clone(), extin.clone(), self.pty_guard.take())
354                                    .fuse();
355                                chanio.set(r);
356                                self.state = CmdlineState::Ready { io };
357                            } else {
358                                warn!("Unexpected Msg::Opened")
359                            }
360                        }
361                        Msg::Exited => {
362                            trace!("SSH exited, finishing cli loop");
363                            break;
364                        }
365                    }
366                    Ok::<_, sunset::Error>(())
367                },
368
369                e = chanio => {
370                    trace!("chanio finished: {e:?}");
371                    cli.exit().await;
372                    break;
373                }
374
375                _ = winch_fut => {
376                    self.window_change_signal().await;
377                    Ok::<_, sunset::Error>(())
378                }
379            }?
380        }
381
382        Ok(())
383    }
384
385    async fn open_session(&mut self, cli: &'a SSHClient<'a>) -> Result<()> {
386        debug_assert!(matches!(self.state, CmdlineState::Authed));
387
388        let (io, extin) = if self.want_pty {
389            set_pty_guard(&mut self.pty_guard);
390            let io = cli.open_session_pty().await?;
391            (io, None)
392        } else {
393            let (io, extin) = cli.open_session_nopty().await?;
394            (io, Some(extin))
395        };
396        self.state = CmdlineState::Opening { io, extin };
397        Ok(())
398    }
399
400    async fn window_change_signal(&mut self) {
401        let io = match &self.state {
402            CmdlineState::Opening { io, ..} => io,
403            CmdlineState::Ready { io, ..} => io,
404            _ => return,
405        };
406
407        let winch = match win_size() {
408            Ok(w) => w,
409            Err(e) => {
410                debug!("Error getting window size: {e:?}");
411                return;
412            }
413        };
414
415        if let Err(e) = io.term_window_change(winch).await {
416            debug!("window change failed: {e:?}");
417        }
418    }
419}
420
421fn set_pty_guard(pty_guard: &mut Option<RawPtyGuard>) {
422    match raw_pty() {
423        Ok(p) => *pty_guard = Some(p),
424        Err(e) => {
425            warn!("Failed getting raw pty: {e:?}");
426        }
427    }
428}
429
/// What the caller should do with a chunk of input passed to
/// [`Escaper::escape`].
#[derive(Debug, PartialEq)]
enum EscapeAction {
    /// Forward nothing; the input was held back as a possible escape.
    None,
    /// Forward the input. `extra` is a byte to send before it.
    Output { extra: Option<u8> },
    /// `~.` was entered.
    Terminate,
    /// `~` followed by ctrl-z was entered.
    Suspend,
}

/// State machine that recognises `~` escape sequences typed at the start of
/// a line.
#[derive(Debug)]
enum Escaper {
    /// In the middle of a line; `~` has no special meaning.
    Idle,
    /// The previous keystroke was `'\r'`; a `~` now starts an escape.
    Newline,
    /// A `~` was held back; the next keystroke selects the action.
    Escape,
}

impl Escaper {
    /// Creates an escaper positioned at the start of a line.
    fn new() -> Self {
        // behave as though a '\r' had just been seen
        Self::Newline
    }

    /// Handle ~. escape sequences.
    ///
    /// Feeds one chunk of input to the state machine and returns what to do
    /// with it. Only single-byte chunks are treated as keystrokes; longer
    /// (or empty) chunks never start or complete an escape, which gives some
    /// protection against pasted escape sequences.
    fn escape(&mut self, buf: &[u8]) -> EscapeAction {
        // A keystroke is a chunk of exactly one byte.
        let key = match buf {
            [c] => Some(*c),
            _ => None,
        };

        let (next, action) = match (&*self, key) {
            // '~' at the start of a line: hold it back
            (Self::Newline, Some(b'~')) => (Self::Escape, EscapeAction::None),
            // "~~": forward the single '~' that is in buf
            (Self::Escape, Some(b'~')) => (Self::Idle, EscapeAction::Output { extra: None }),
            (Self::Escape, Some(b'.')) => (Self::Idle, EscapeAction::Terminate),
            // ctrl-z, suspend
            (Self::Escape, Some(0x1a)) => (Self::Idle, EscapeAction::Suspend),
            // Anything else is plain input and resets the escape state.
            (state, pressed) => {
                // a held-back '~' turned out not to be an escape, release it
                let extra = match state {
                    Self::Escape => Some(b'~'),
                    _ => None,
                };
                let next = if pressed == Some(b'\r') {
                    Self::Newline
                } else {
                    Self::Idle
                };
                (next, EscapeAction::Output { extra })
            }
        };

        *self = next;
        action
    }
}
507
508impl<'a> CmdlineHooks<'a> {
509    /// Notify the `CmdlineClient` that the main SSH session has exited.
510    ///
511    /// This will cause the `CmdlineRunner` to finish flushing output and terminate.
512    pub async fn exited(&mut self) {
513        self.notify.send(Msg::Exited).await
514    }
515}
516
517impl sunset::CliBehaviour for CmdlineHooks<'_> {
518    fn username(&mut self) -> BhResult<sunset::ResponseString> {
519        sunset::ResponseString::from_str(&self.username).map_err(|_| BhError::Fail)
520    }
521
522    fn valid_hostkey(&mut self, key: &sunset::PubKey) -> BhResult<bool> {
523        trace!("checking hostkey for {key:?}");
524
525        match knownhosts::check_known_hosts(self.host, self.port, key) {
526            Ok(()) => Ok(true),
527            Err(e) => {
528                debug!("Error for hostkey: {e:?}");
529                Ok(false)
530            }
531        }
532    }
533
534    fn next_authkey(&mut self) -> BhResult<Option<sunset::SignKey>> {
535        Ok(self.authkeys.pop_front())
536    }
537
538    fn auth_password(
539        &mut self,
540        pwbuf: &mut sunset::ResponseString,
541    ) -> BhResult<bool> {
542        let pw =
543            rpassword::prompt_password(format!("password for {}: ", self.username))
544                .map_err(|e| {
545                    warn!("read_password failed {e:}");
546                    BhError::Fail
547                })?;
548        if pwbuf.push_str(&pw).is_err() {
549            Err(BhError::Fail)
550        } else {
551            Ok(true)
552        }
553    }
554
555    async fn agent_sign(&mut self, key: &SignKey, msg: &AuthSigMsg<'_>) -> BhResult<OwnedSig> {
556        if let Some(ref mut agent) = self.agent {
557            agent.sign_auth(key, msg).await.map_err(|_e| {
558                error!("agent signing failed");
559                BhError::Fail
560            })
561        } else {
562            error!("agent signing wrong");
563            Err(BhError::Fail)
564        }
565    }
566
567    fn authenticated(&mut self) {
568        debug!("Authentication succeeded");
569        // TODO: need better handling, what else could we do?
570        while self.notify.try_send(Msg::Authed).is_err() {
571            warn!("Full notification queue");
572        }
573    }
574
575    async fn session_opened(&mut self, _chan: sunset::ChanNum, opener: &mut sunset::SessionOpener<'_, '_, '_>) -> BhResult<()> {
576        if let Some(p) = self.pty.take() {
577            opener.pty(p)
578        }
579        opener.cmd(self.cmd);
580        self.notify.send(Msg::Opened).await;
581        Ok(())
582    }
583}
584
585impl<'a> Debug for CmdlineHooks<'a> {
586    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
587        f.write_str("CmdlineHooks")
588    }
589}
590
#[cfg(test)]
pub(crate) mod tests {
    use crate::cmdline_client::*;

    /// Feeds each input to a fresh `Escaper` one byte at a time and checks
    /// both the bytes that would be forwarded and the final action.
    #[test]
    fn escaping() {
        // Columns: input, expected final action, expected forwarded bytes.
        // A `None` expected action is shorthand for any `::Output`.
        let seqs = vec![
            ("~.", Some(EscapeAction::Terminate), ""),
            ("\r~.", Some(EscapeAction::Terminate), "\r"),
            ("~~.", None, "~."),
            ("~~~.", None, "~~."),
            ("\r\r~.", Some(EscapeAction::Terminate), "\r\r"),
            ("a~/~.", None, "a~/~."),
            ("a~/\r~.", Some(EscapeAction::Terminate), "a~/\r"),
            ("~\r~.", Some(EscapeAction::Terminate), "~\r"),
            ("~\r~ ", None, "~\r~ "),
        ];

        for (inp, expect_action, expect) in seqs.iter() {
            let mut esc = Escaper::new();
            let mut out: Vec<u8> = vec![];
            let mut last_action = None;
            println!("input \"{}\"", inp.escape_default());

            for b in inp.bytes() {
                let action = esc.escape(&[b]);
                if let EscapeAction::Output { extra } = &action {
                    out.extend(*extra);
                    out.push(b);
                }
                last_action = Some(action);
            }
            assert_eq!(out.as_slice(), expect.as_bytes());

            match (last_action.unwrap(), expect_action) {
                (last_action, Some(expect_action)) => {
                    assert_eq!(&last_action, expect_action)
                }
                (EscapeAction::Output { .. }, None) => (),
                (last_action, None) => panic!("Unexpected action {last_action:?}"),
            }
        }
    }
}
642