support_kit/hosts/ssh_session.rs

1use std::{
2    path::{Path, PathBuf},
3    sync::Arc,
4    time::Duration,
5};
6
7use russh::ChannelMsg;
8use tokio::io::AsyncWriteExt;
9
10use crate::SshError;
11
12use super::{HostDetails, SshConnection};
13
14pub struct SshSession {
15    pub connection: russh::client::Handle<SshConnection>,
16}
17
18impl SshSession {
19    #[tracing::instrument(skip(host), level = "debug")]
20    pub async fn connect(host: &HostDetails) -> Result<Self, SshError> {
21        let config = Arc::new(russh::client::Config {
22            inactivity_timeout: Some(Duration::from_secs(5)),
23            ..<_>::default()
24        });
25
26        let mut session =
27            russh::client::connect(config, (host.address.as_ref(), host.port), SshConnection)
28                .await?;
29
30        tracing::debug!("canonicalizing path to key: {path}", path = host.auth);
31        let path = expand_tilde(&host.auth).ok_or(SshError::InvalidPath(host.auth.clone()))?;
32
33        let key_pair = russh::keys::load_secret_key(path, None)?;
34        let auth_res = session
35            .authenticate_publickey(&host.user, Arc::new(key_pair))
36            .await?;
37
38        if !auth_res {
39            return Err(SshError::AuthenticationFailed);
40        }
41
42        tracing::debug!("ssh session established: {address}", address = host.address);
43
44        Ok(SshSession {
45            connection: session,
46        })
47    }
48
49    #[tracing::instrument(skip(self, command), level = "debug")]
50    pub async fn run_cmd<T>(&self, command: Vec<T>) -> Result<(), SshError>
51    where
52        T: AsRef<str>,
53    {
54        let mut channel = self.connection.channel_open_session().await?;
55        let command = command
56            .into_iter()
57            .map(|x| shell_escape::escape(x.as_ref().to_owned().into()))
58            .collect::<Vec<_>>()
59            .join(" ");
60
61        channel.exec(true, command).await?;
62
63        let mut code = None;
64        let mut stdout = tokio::io::stdout();
65
66        loop {
67            // There's an event available on the session channel
68            let Some(msg) = channel.wait().await else {
69                tracing::trace!("channel closed");
70                break;
71            };
72
73            match msg {
74                // Write data to the terminal
75                ChannelMsg::Data { ref data } => {
76                    tracing::trace!(
77                        "received data: {data}",
78                        data = String::from_utf8_lossy(data)
79                    );
80                    stdout.write_all(data).await?;
81                    stdout.flush().await?;
82                }
83                // The command has returned an exit code
84                ChannelMsg::ExitStatus { exit_status } => {
85                    tracing::trace!("exit status: {exit_status}", exit_status = exit_status);
86                    code = Some(exit_status);
87                    // cannot leave the loop immediately, there might still be more data to receive
88                }
89                other => {
90                    tracing::trace!("unhandled channel message: {:?}", other);
91                }
92            }
93        }
94
95        // Wait for the channel to close
96        channel.close().await?;
97
98        // report code
99
100        if let Some(code) = code {
101            println!("Exit code: {}", code);
102        }
103
104        Ok(())
105    }
106}
107
108// definitely an easier way to do this, but for now, cribbed from
109// https://stackoverflow.com/questions/54267608/expand-tilde-in-rust-path-idiomatically
110#[tracing::instrument(skip(path_user_input), level = "trace")]
111fn expand_tilde<P: AsRef<Path>>(path_user_input: P) -> Option<PathBuf> {
112    let path = path_user_input.as_ref();
113    if !path.starts_with("~") {
114        return Some(path.to_path_buf());
115    }
116    if path == Path::new("~") {
117        return dirs::home_dir();
118    }
119    dirs::home_dir().map(|mut home| {
120        if home == Path::new("/") {
121            // Corner case: `home` root directory;
122            // don't prepend extra `/`, just drop the tilde.
123            path.strip_prefix("~").unwrap().to_path_buf()
124        } else {
125            home.push(path.strip_prefix("~/").unwrap());
126            home
127        }
128    })
129}
130
// Compares against `dirs::home_dir()`, the same source `expand_tilde` uses,
// instead of `$HOME`. The test then runs wherever a home directory can be
// resolved, not only where `$HOME` is set.
#[test]
fn test_expand_tilde() {
    let home = dirs::home_dir().expect("test requires a resolvable home directory");

    assert_eq!(expand_tilde("~"), Some(home.clone()));
    assert_eq!(expand_tilde("~/Projects"), Some(home.join("Projects")));
    // Paths without a leading `~` component pass through untouched.
    assert_eq!(expand_tilde("/foo/bar"), Some(PathBuf::from("/foo/bar")));
    // `~user` syntax is not supported and is returned verbatim.
    assert_eq!(
        expand_tilde("~alice/projects"),
        Some(PathBuf::from("~alice/projects"))
    );
}