Skip to main content

xbp_cli/commands/
ssh_session.rs

1use crate::commands::cloudflared_access::{CloudflaredTcpOptions, CloudflaredTunnel};
2use crate::commands::ssh_helpers::{resolve_or_prompt, resolve_or_prompt_password};
3use crate::config::SshConfig;
4use async_ssh2_tokio::client::{AuthMethod, Client, ServerCheckMethod};
5use crossterm::terminal::{disable_raw_mode, enable_raw_mode, size as terminal_size};
6use russh::{ChannelMsg, Sig};
7use std::env;
8use std::net::Ipv4Addr;
9use std::path::PathBuf;
10use tokio::io::{stderr, stdin, stdout, AsyncReadExt, AsyncWriteExt};
11use tokio::signal;
12use tracing::debug;
13
/// `TERM` value requested for the remote PTY when neither an explicit value
/// nor a non-blank `$TERM` environment variable is available.
const DEFAULT_TERM: &'static str = "xterm-256color";

/// Caller-supplied settings for [`run_interactive_shell`].
///
/// Host, username and password are optional here: a `None` value is handed to
/// `resolve_or_prompt` / `resolve_or_prompt_password` together with the
/// matching field of the saved `SshConfig`.
#[derive(Debug, Clone)]
pub struct InteractiveShellOptions {
    /// SSH server host. When a cloudflared tunnel is used the connection goes to
    /// the tunnel's local port instead, and this value is only logged.
    pub ssh_host: Option<String>,
    /// Port for direct connections; not used when tunnelling through cloudflared.
    pub ssh_port: u16,
    /// Remote login name.
    pub ssh_username: Option<String>,
    /// Password, consulted only when `private_key` is `None`.
    pub ssh_password: Option<String>,
    /// Private key file; takes precedence over password authentication.
    pub private_key: Option<PathBuf>,
    /// Passphrase for `private_key`, if it is encrypted.
    pub private_key_passphrase: Option<String>,
    /// Remote command to execute instead of starting a login shell.
    pub command: Option<String>,
    /// `TERM` for the remote PTY; falls back to `$TERM`, then `DEFAULT_TERM`.
    pub term: Option<String>,
    /// Skip host-key verification (ignored when `host_key` is set).
    pub no_host_key_check: bool,
    /// Pinned server public key; overrides every other host-key option.
    pub host_key: Option<String>,
    /// known_hosts file to verify the server key against.
    pub known_hosts_file: Option<PathBuf>,
    /// When set, connect through a cloudflared TCP tunnel to this hostname.
    pub cloudflared_hostname: Option<String>,
    /// Path to the cloudflared executable (passed as the tunnel's `binary_path`).
    pub cloudflared_binary: Option<PathBuf>,
    /// Passed to the tunnel as its `destination` option.
    pub cloudflared_destination: Option<String>,
}

34pub async fn run_interactive_shell(
35    options: InteractiveShellOptions,
36    debug_mode: bool,
37) -> Result<(), String> {
38    let mut config = SshConfig::load().map_err(|e| format!("Failed to load SSH config: {}", e))?;
39    let mut config_dirty = false;
40
41    let resolved_host = resolve_or_prompt(
42        options.ssh_host.clone(),
43        &mut config.host,
44        "Enter SSH host: ",
45        &mut config_dirty,
46    )?;
47    let resolved_username = resolve_or_prompt(
48        options.ssh_username.clone(),
49        &mut config.username,
50        "Enter SSH username: ",
51        &mut config_dirty,
52    )?;
53
54    let auth = resolve_auth_method(&options, &mut config, &mut config_dirty)?;
55
56    if config_dirty {
57        config
58            .save()
59            .map_err(|e| format!("Failed to save SSH config: {}", e))?;
60    }
61
62    let mut cloudflared = if let Some(hostname) = options.cloudflared_hostname.as_deref() {
63        Some(
64            CloudflaredTunnel::start(
65                CloudflaredTcpOptions {
66                    hostname: hostname.to_string(),
67                    listener: None,
68                    destination: options.cloudflared_destination.clone(),
69                    binary_path: options.cloudflared_binary.clone(),
70                },
71                debug_mode,
72            )
73            .await?,
74        )
75    } else {
76        None
77    };
78
79    let server_check = resolve_server_check(&options, cloudflared.is_some());
80    let term = resolve_term_value(options.term.clone());
81
82    let connect_result = if let Some(tunnel) = cloudflared.as_ref() {
83        if debug_mode {
84            debug!(
85                "SSH shell via cloudflared => tunnel host: {}, local port: {}, remote identity: {}, user: {}",
86                tunnel.hostname, tunnel.local_port, resolved_host, resolved_username
87            );
88        }
89
90        Client::connect(
91            (Ipv4Addr::LOCALHOST, tunnel.local_port),
92            resolved_username.as_str(),
93            auth,
94            server_check,
95        )
96        .await
97        .map_err(|e| format!("SSH connection failed: {}", e))
98    } else {
99        if debug_mode {
100            debug!(
101                "SSH shell direct => host: {}, port: {}, user: {}",
102                resolved_host, options.ssh_port, resolved_username
103            );
104        }
105
106        Client::connect(
107            (resolved_host.as_str(), options.ssh_port),
108            resolved_username.as_str(),
109            auth,
110            server_check,
111        )
112        .await
113        .map_err(|e| format!("SSH connection failed: {}", e))
114    };
115
116    let client = match connect_result {
117        Ok(client) => client,
118        Err(err) => {
119            if let Some(tunnel) = cloudflared.as_mut() {
120                tunnel.shutdown().await;
121            }
122            return Err(err);
123        }
124    };
125
126    let session_result =
127        run_channel_session(&client, &term, options.command.as_deref(), debug_mode).await;
128    if let Err(err) = client.disconnect().await {
129        debug!("Failed to cleanly disconnect SSH session: {}", err);
130    }
131    if let Some(tunnel) = cloudflared.as_mut() {
132        tunnel.shutdown().await;
133    }
134
135    session_result
136}
137
138fn resolve_auth_method(
139    options: &InteractiveShellOptions,
140    config: &mut SshConfig,
141    config_dirty: &mut bool,
142) -> Result<AuthMethod, String> {
143    if let Some(key_path) = options.private_key.as_deref() {
144        return Ok(AuthMethod::with_key_file(
145            key_path,
146            options.private_key_passphrase.as_deref(),
147        ));
148    }
149
150    let password = resolve_or_prompt_password(
151        options.ssh_password.clone(),
152        &mut config.password,
153        "Enter SSH password: ",
154        config_dirty,
155    )?;
156    Ok(AuthMethod::with_password(&password))
157}
158
159fn resolve_server_check(
160    options: &InteractiveShellOptions,
161    using_cloudflared: bool,
162) -> ServerCheckMethod {
163    if let Some(host_key) = options.host_key.as_deref().map(str::trim) {
164        if !host_key.is_empty() {
165            return ServerCheckMethod::with_public_key(host_key);
166        }
167    }
168
169    if options.no_host_key_check {
170        return ServerCheckMethod::NoCheck;
171    }
172
173    if let Some(path) = options.known_hosts_file.as_deref() {
174        return ServerCheckMethod::with_known_hosts_file(&path.to_string_lossy());
175    }
176
177    if using_cloudflared {
178        return ServerCheckMethod::NoCheck;
179    }
180
181    ServerCheckMethod::DefaultKnownHostsFile
182}
183
184fn resolve_term_value(explicit_term: Option<String>) -> String {
185    explicit_term
186        .and_then(normalize_optional_string)
187        .or_else(|| env::var("TERM").ok().and_then(normalize_optional_string))
188        .unwrap_or_else(|| DEFAULT_TERM.to_string())
189}
190
/// Trims surrounding whitespace and returns `None` if nothing remains.
fn normalize_optional_string(value: String) -> Option<String> {
    let trimmed = value.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_string())
}

200async fn run_channel_session(
201    client: &Client,
202    term: &str,
203    command: Option<&str>,
204    debug_mode: bool,
205) -> Result<(), String> {
206    let channel = client
207        .get_channel()
208        .await
209        .map_err(|e| format!("Failed to open SSH channel: {}", e))?;
210
211    let (cols, rows) = terminal_size().unwrap_or((120, 32));
212    channel
213        .request_pty(false, term, u32::from(cols), u32::from(rows), 0, 0, &[])
214        .await
215        .map_err(|e| format!("Failed to request remote PTY: {}", e))?;
216
217    if let Some(command) = command {
218        channel
219            .exec(true, command)
220            .await
221            .map_err(|e| format!("Failed to execute remote command: {}", e))?;
222    } else {
223        channel
224            .request_shell(true)
225            .await
226            .map_err(|e| format!("Failed to start remote shell: {}", e))?;
227    }
228
229    if debug_mode {
230        debug!(
231            "SSH channel ready => term: {}, cols: {}, rows: {}, command: {}",
232            term,
233            cols,
234            rows,
235            command.unwrap_or("<login-shell>")
236        );
237    }
238
239    let _raw_mode = RawModeGuard::enable()?;
240    stream_interactive_channel(channel).await
241}
242
243async fn stream_interactive_channel(
244    mut channel: russh::Channel<russh::client::Msg>,
245) -> Result<(), String> {
246    let mut stdin = stdin();
247    let mut stdout = stdout();
248    let mut stderr = stderr();
249    let mut read_buf = [0_u8; 8192];
250    let mut exit_status: Option<u32> = None;
251    let mut stdin_closed = false;
252    let ctrl_c = signal::ctrl_c();
253    tokio::pin!(ctrl_c);
254
255    loop {
256        tokio::select! {
257            read_result = stdin.read(&mut read_buf), if !stdin_closed => {
258                match read_result {
259                    Ok(0) => {
260                        stdin_closed = true;
261                        channel
262                            .eof()
263                            .await
264                            .map_err(|e| format!("Failed to close remote stdin: {}", e))?;
265                    }
266                    Ok(read_len) => {
267                        channel
268                            .data(&read_buf[..read_len])
269                            .await
270                            .map_err(|e| format!("Failed to send SSH input: {}", e))?;
271                    }
272                    Err(err) => return Err(format!("Failed to read terminal input: {}", err)),
273                }
274            }
275            msg = channel.wait() => match msg {
276                Some(ChannelMsg::Data { ref data }) => {
277                    stdout
278                        .write_all(data)
279                        .await
280                        .map_err(|e| format!("Failed to write remote stdout: {}", e))?;
281                    stdout
282                        .flush()
283                        .await
284                        .map_err(|e| format!("Failed to flush stdout: {}", e))?;
285                }
286                Some(ChannelMsg::ExtendedData { ref data, ext }) => {
287                    if ext == 1 {
288                        stderr
289                            .write_all(data)
290                            .await
291                            .map_err(|e| format!("Failed to write remote stderr: {}", e))?;
292                        stderr
293                            .flush()
294                            .await
295                            .map_err(|e| format!("Failed to flush stderr: {}", e))?;
296                    }
297                }
298                Some(ChannelMsg::ExitStatus { exit_status: status }) => {
299                    exit_status = Some(status);
300                }
301                Some(ChannelMsg::ExitSignal { signal_name, .. }) => {
302                    if exit_status.is_none() {
303                        exit_status = Some(signal_to_exit_status(&signal_name));
304                    }
305                }
306                Some(ChannelMsg::Eof) | Some(ChannelMsg::Close) | None => {
307                    break;
308                }
309                Some(_) => {}
310            },
311            _ = &mut ctrl_c => {
312                channel
313                    .signal(Sig::INT)
314                    .await
315                    .map_err(|e| format!("Failed to send interrupt signal: {}", e))?;
316            }
317        }
318    }
319
320    match exit_status {
321        Some(0) | None => Ok(()),
322        Some(status) => Err(format!("Remote shell exited with status: {}", status)),
323    }
324}
325
326fn signal_to_exit_status(signal: &Sig) -> u32 {
327    match signal {
328        Sig::INT => 130,
329        Sig::TERM => 143,
330        Sig::QUIT => 131,
331        Sig::KILL => 137,
332        Sig::HUP => 129,
333        Sig::PIPE => 141,
334        _ => 128,
335    }
336}
337
338struct RawModeGuard;
339
340impl RawModeGuard {
341    fn enable() -> Result<Self, String> {
342        enable_raw_mode().map_err(|e| format!("Failed to enable raw terminal mode: {}", e))?;
343        Ok(Self)
344    }
345}
346
347impl Drop for RawModeGuard {
348    fn drop(&mut self) {
349        let _ = disable_raw_mode();
350    }
351}
352
#[cfg(test)]
mod tests {
    use super::{
        normalize_optional_string, resolve_server_check, resolve_term_value,
        InteractiveShellOptions, DEFAULT_TERM,
    };
    use async_ssh2_tokio::client::ServerCheckMethod;
    use std::env;
    use std::path::PathBuf;

    /// Direct-SSH options with password auth and no host-key overrides. Each test
    /// changes only the fields it cares about.
    fn base_options() -> InteractiveShellOptions {
        InteractiveShellOptions {
            ssh_host: Some("prod.example.com".to_string()),
            ssh_port: 22,
            ssh_username: Some("deploy".to_string()),
            ssh_password: Some("secret".to_string()),
            private_key: None,
            private_key_passphrase: None,
            command: None,
            term: None,
            no_host_key_check: false,
            host_key: None,
            known_hosts_file: None,
            cloudflared_hostname: None,
            cloudflared_binary: None,
            cloudflared_destination: None,
        }
    }

    #[test]
    fn explicit_host_key_wins() {
        // A pinned key must beat both `no_host_key_check` and the cloudflared default.
        let options = InteractiveShellOptions {
            host_key: Some("AAAAB3NzaC1yc2EAAAADAQABAAABAQDc".to_string()),
            no_host_key_check: true,
            ..base_options()
        };

        let check = resolve_server_check(&options, true);
        assert!(matches!(check, ServerCheckMethod::PublicKey(_)));
    }

    #[test]
    fn cloudflared_defaults_to_no_check_without_pin() {
        let check = resolve_server_check(&base_options(), true);
        assert!(matches!(check, ServerCheckMethod::NoCheck));
    }

    #[test]
    fn known_hosts_file_is_used_for_direct_ssh() {
        let options = InteractiveShellOptions {
            known_hosts_file: Some(PathBuf::from("C:/Users/floris/.ssh/known_hosts")),
            ..base_options()
        };

        let check = resolve_server_check(&options, false);
        assert!(matches!(check, ServerCheckMethod::KnownHostsFile(_)));
    }

    #[test]
    fn resolve_term_prefers_explicit_value() {
        assert_eq!(
            resolve_term_value(Some("screen-256color".to_string())),
            "screen-256color"
        );
    }

    #[test]
    fn resolve_term_falls_back_to_default() {
        let saved_term = env::var("TERM").ok();
        // SAFETY: no other test in this module reads TERM (an explicit term is
        // resolved without consulting the environment). NOTE(review): tests in
        // other modules could still race on the process environment.
        unsafe {
            env::remove_var("TERM");
        }

        let resolved = resolve_term_value(None);

        // Restore TERM before asserting so a failure does not leak the change.
        if let Some(value) = saved_term {
            // SAFETY: same reasoning as the removal above.
            unsafe {
                env::set_var("TERM", value);
            }
        }

        assert_eq!(resolved, DEFAULT_TERM);
    }

    #[test]
    fn normalize_optional_string_trims_and_filters_empty_values() {
        let trimmed = normalize_optional_string("  cloudflared.example.com ".to_string());
        assert_eq!(trimmed, Some("cloudflared.example.com".to_string()));

        let blank = normalize_optional_string("   ".to_string());
        assert_eq!(blank, None);
    }
}