support_kit/hosts/
ssh_session.rs1use std::{
2 path::{Path, PathBuf},
3 sync::Arc,
4 time::Duration,
5};
6
7use russh::ChannelMsg;
8use tokio::io::AsyncWriteExt;
9
10use crate::SshError;
11
12use super::{HostDetails, SshConnection};
13
14pub struct SshSession {
15 pub connection: russh::client::Handle<SshConnection>,
16}
17
18impl SshSession {
19 #[tracing::instrument(skip(host), level = "debug")]
20 pub async fn connect(host: &HostDetails) -> Result<Self, SshError> {
21 let config = Arc::new(russh::client::Config {
22 inactivity_timeout: Some(Duration::from_secs(5)),
23 ..<_>::default()
24 });
25
26 let mut session =
27 russh::client::connect(config, (host.address.as_ref(), host.port), SshConnection)
28 .await?;
29
30 tracing::debug!("canonicalizing path to key: {path}", path = host.auth);
31 let path = expand_tilde(&host.auth).ok_or(SshError::InvalidPath(host.auth.clone()))?;
32
33 let key_pair = russh::keys::load_secret_key(path, None)?;
34 let auth_res = session
35 .authenticate_publickey(&host.user, Arc::new(key_pair))
36 .await?;
37
38 if !auth_res {
39 return Err(SshError::AuthenticationFailed);
40 }
41
42 tracing::debug!("ssh session established: {address}", address = host.address);
43
44 Ok(SshSession {
45 connection: session,
46 })
47 }
48
49 #[tracing::instrument(skip(self, command), level = "debug")]
50 pub async fn run_cmd<T>(&self, command: Vec<T>) -> Result<(), SshError>
51 where
52 T: AsRef<str>,
53 {
54 let mut channel = self.connection.channel_open_session().await?;
55 let command = command
56 .into_iter()
57 .map(|x| shell_escape::escape(x.as_ref().to_owned().into()))
58 .collect::<Vec<_>>()
59 .join(" ");
60
61 channel.exec(true, command).await?;
62
63 let mut code = None;
64 let mut stdout = tokio::io::stdout();
65
66 loop {
67 let Some(msg) = channel.wait().await else {
69 tracing::trace!("channel closed");
70 break;
71 };
72
73 match msg {
74 ChannelMsg::Data { ref data } => {
76 tracing::trace!(
77 "received data: {data}",
78 data = String::from_utf8_lossy(data)
79 );
80 stdout.write_all(data).await?;
81 stdout.flush().await?;
82 }
83 ChannelMsg::ExitStatus { exit_status } => {
85 tracing::trace!("exit status: {exit_status}", exit_status = exit_status);
86 code = Some(exit_status);
87 }
89 other => {
90 tracing::trace!("unhandled channel message: {:?}", other);
91 }
92 }
93 }
94
95 channel.close().await?;
97
98 if let Some(code) = code {
101 println!("Exit code: {}", code);
102 }
103
104 Ok(())
105 }
106}
107
108#[tracing::instrument(skip(path_user_input), level = "trace")]
111fn expand_tilde<P: AsRef<Path>>(path_user_input: P) -> Option<PathBuf> {
112 let path = path_user_input.as_ref();
113 if !path.starts_with("~") {
114 return Some(path.to_path_buf());
115 }
116 if path == Path::new("~") {
117 return dirs::home_dir();
118 }
119 dirs::home_dir().map(|mut home| {
120 if home == Path::new("/") {
121 path.strip_prefix("~").unwrap().to_path_buf()
124 } else {
125 home.push(path.strip_prefix("~/").unwrap());
126 home
127 }
128 })
129}
130
/// Checks `expand_tilde` against the home directory reported by
/// `dirs::home_dir()`, the same source the function itself uses.
#[test]
fn test_expand_tilde() {
    // Avoid `std::env::var("HOME").unwrap()`, which panics where `HOME` is unset.
    let home = dirs::home_dir().expect("home directory should be resolvable in tests");

    assert_eq!(expand_tilde("~"), Some(home.clone()));
    // Build the expectation with `Path::join` rather than string concatenation.
    assert_eq!(expand_tilde("~/Projects"), Some(home.join("Projects")));
    assert_eq!(expand_tilde("/foo/bar"), Some("/foo/bar".into()));
    assert_eq!(expand_tilde("relative/path"), Some("relative/path".into()));
    // `~user` forms are intentionally left untouched.
    assert_eq!(
        expand_tilde("~alice/projects"),
        Some("~alice/projects".into())
    );
}