1use crate::commands::cloudflared_access::{CloudflaredTcpOptions, CloudflaredTunnel};
2use crate::commands::ssh_helpers::{resolve_or_prompt, resolve_or_prompt_password};
3use crate::config::SshConfig;
4use async_ssh2_tokio::client::{AuthMethod, Client, ServerCheckMethod};
5use crossterm::terminal::{disable_raw_mode, enable_raw_mode, size as terminal_size};
6use russh::{ChannelMsg, Sig};
7use std::env;
8use std::net::Ipv4Addr;
9use std::path::PathBuf;
10use tokio::io::{stderr, stdin, stdout, AsyncReadExt, AsyncWriteExt};
11use tokio::signal;
12use tracing::debug;
13
/// Terminal type requested for the remote PTY when neither an explicit `term`
/// option nor a non-blank `TERM` environment variable is available.
const DEFAULT_TERM: &str = "xterm-256color";
15
/// Options for [`run_interactive_shell`].
///
/// Host, user name and password left as `None` are resolved at runtime against the
/// stored SSH config and, failing that, an interactive prompt.
#[derive(Clone)]
pub struct InteractiveShellOptions {
    /// SSH host. When tunnelling through cloudflared this is only used for logging;
    /// the client connects to the local tunnel port instead.
    pub ssh_host: Option<String>,
    /// Port for direct connections; unused when tunnelling through cloudflared.
    pub ssh_port: u16,
    /// Login user name.
    pub ssh_username: Option<String>,
    /// Password; only consulted when no `private_key` is given.
    pub ssh_password: Option<String>,
    /// Private key file; when set, key authentication is used instead of a password.
    pub private_key: Option<PathBuf>,
    /// Passphrase handed to the key loader together with `private_key`.
    pub private_key_passphrase: Option<String>,
    /// Remote command to run instead of a login shell.
    pub command: Option<String>,
    /// `TERM` value for the remote PTY; falls back to the local `TERM`, then
    /// `DEFAULT_TERM`.
    pub term: Option<String>,
    /// Skip host-key verification (ignored when a non-blank `host_key` is pinned).
    pub no_host_key_check: bool,
    /// Pinned server public key; takes precedence over every other check option.
    pub host_key: Option<String>,
    /// Known-hosts file to verify the server against.
    pub known_hosts_file: Option<PathBuf>,
    /// When set, connect through a cloudflared TCP tunnel to this hostname.
    pub cloudflared_hostname: Option<String>,
    /// cloudflared executable, passed through as the tunnel's `binary_path`.
    pub cloudflared_binary: Option<PathBuf>,
    /// Passed through as the tunnel's `destination` (semantics defined by the
    /// cloudflared launcher).
    pub cloudflared_destination: Option<String>,
}

impl std::fmt::Debug for InteractiveShellOptions {
    /// Formats every field like `#[derive(Debug)]` would, except that the password
    /// and key passphrase are replaced by a placeholder so `{:?}` logging of the
    /// options can never leak credentials.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fn redact(secret: Option<&str>) -> Option<&'static str> {
            secret.map(|_| "<redacted>")
        }

        f.debug_struct("InteractiveShellOptions")
            .field("ssh_host", &self.ssh_host)
            .field("ssh_port", &self.ssh_port)
            .field("ssh_username", &self.ssh_username)
            .field("ssh_password", &redact(self.ssh_password.as_deref()))
            .field("private_key", &self.private_key)
            .field(
                "private_key_passphrase",
                &redact(self.private_key_passphrase.as_deref()),
            )
            .field("command", &self.command)
            .field("term", &self.term)
            .field("no_host_key_check", &self.no_host_key_check)
            .field("host_key", &self.host_key)
            .field("known_hosts_file", &self.known_hosts_file)
            .field("cloudflared_hostname", &self.cloudflared_hostname)
            .field("cloudflared_binary", &self.cloudflared_binary)
            .field("cloudflared_destination", &self.cloudflared_destination)
            .finish()
    }
}
33
34pub async fn run_interactive_shell(
35 options: InteractiveShellOptions,
36 debug_mode: bool,
37) -> Result<(), String> {
38 let mut config = SshConfig::load().map_err(|e| format!("Failed to load SSH config: {}", e))?;
39 let mut config_dirty = false;
40
41 let resolved_host = resolve_or_prompt(
42 options.ssh_host.clone(),
43 &mut config.host,
44 "Enter SSH host: ",
45 &mut config_dirty,
46 )?;
47 let resolved_username = resolve_or_prompt(
48 options.ssh_username.clone(),
49 &mut config.username,
50 "Enter SSH username: ",
51 &mut config_dirty,
52 )?;
53
54 let auth = resolve_auth_method(&options, &mut config, &mut config_dirty)?;
55
56 if config_dirty {
57 config
58 .save()
59 .map_err(|e| format!("Failed to save SSH config: {}", e))?;
60 }
61
62 let mut cloudflared = if let Some(hostname) = options.cloudflared_hostname.as_deref() {
63 Some(
64 CloudflaredTunnel::start(
65 CloudflaredTcpOptions {
66 hostname: hostname.to_string(),
67 listener: None,
68 destination: options.cloudflared_destination.clone(),
69 binary_path: options.cloudflared_binary.clone(),
70 },
71 debug_mode,
72 )
73 .await?,
74 )
75 } else {
76 None
77 };
78
79 let server_check = resolve_server_check(&options, cloudflared.is_some());
80 let term = resolve_term_value(options.term.clone());
81
82 let connect_result = if let Some(tunnel) = cloudflared.as_ref() {
83 if debug_mode {
84 debug!(
85 "SSH shell via cloudflared => tunnel host: {}, local port: {}, remote identity: {}, user: {}",
86 tunnel.hostname, tunnel.local_port, resolved_host, resolved_username
87 );
88 }
89
90 Client::connect(
91 (Ipv4Addr::LOCALHOST, tunnel.local_port),
92 resolved_username.as_str(),
93 auth,
94 server_check,
95 )
96 .await
97 .map_err(|e| format!("SSH connection failed: {}", e))
98 } else {
99 if debug_mode {
100 debug!(
101 "SSH shell direct => host: {}, port: {}, user: {}",
102 resolved_host, options.ssh_port, resolved_username
103 );
104 }
105
106 Client::connect(
107 (resolved_host.as_str(), options.ssh_port),
108 resolved_username.as_str(),
109 auth,
110 server_check,
111 )
112 .await
113 .map_err(|e| format!("SSH connection failed: {}", e))
114 };
115
116 let client = match connect_result {
117 Ok(client) => client,
118 Err(err) => {
119 if let Some(tunnel) = cloudflared.as_mut() {
120 tunnel.shutdown().await;
121 }
122 return Err(err);
123 }
124 };
125
126 let session_result =
127 run_channel_session(&client, &term, options.command.as_deref(), debug_mode).await;
128 if let Err(err) = client.disconnect().await {
129 debug!("Failed to cleanly disconnect SSH session: {}", err);
130 }
131 if let Some(tunnel) = cloudflared.as_mut() {
132 tunnel.shutdown().await;
133 }
134
135 session_result
136}
137
138fn resolve_auth_method(
139 options: &InteractiveShellOptions,
140 config: &mut SshConfig,
141 config_dirty: &mut bool,
142) -> Result<AuthMethod, String> {
143 if let Some(key_path) = options.private_key.as_deref() {
144 return Ok(AuthMethod::with_key_file(
145 key_path,
146 options.private_key_passphrase.as_deref(),
147 ));
148 }
149
150 let password = resolve_or_prompt_password(
151 options.ssh_password.clone(),
152 &mut config.password,
153 "Enter SSH password: ",
154 config_dirty,
155 )?;
156 Ok(AuthMethod::with_password(&password))
157}
158
159fn resolve_server_check(
160 options: &InteractiveShellOptions,
161 using_cloudflared: bool,
162) -> ServerCheckMethod {
163 if let Some(host_key) = options.host_key.as_deref().map(str::trim) {
164 if !host_key.is_empty() {
165 return ServerCheckMethod::with_public_key(host_key);
166 }
167 }
168
169 if options.no_host_key_check {
170 return ServerCheckMethod::NoCheck;
171 }
172
173 if let Some(path) = options.known_hosts_file.as_deref() {
174 return ServerCheckMethod::with_known_hosts_file(&path.to_string_lossy());
175 }
176
177 if using_cloudflared {
178 return ServerCheckMethod::NoCheck;
179 }
180
181 ServerCheckMethod::DefaultKnownHostsFile
182}
183
184fn resolve_term_value(explicit_term: Option<String>) -> String {
185 explicit_term
186 .and_then(normalize_optional_string)
187 .or_else(|| env::var("TERM").ok().and_then(normalize_optional_string))
188 .unwrap_or_else(|| DEFAULT_TERM.to_string())
189}
190
/// Trims surrounding whitespace and returns `None` if nothing remains.
fn normalize_optional_string(value: String) -> Option<String> {
    Some(value.trim())
        .filter(|candidate| !candidate.is_empty())
        .map(str::to_owned)
}
199
200async fn run_channel_session(
201 client: &Client,
202 term: &str,
203 command: Option<&str>,
204 debug_mode: bool,
205) -> Result<(), String> {
206 let channel = client
207 .get_channel()
208 .await
209 .map_err(|e| format!("Failed to open SSH channel: {}", e))?;
210
211 let (cols, rows) = terminal_size().unwrap_or((120, 32));
212 channel
213 .request_pty(false, term, u32::from(cols), u32::from(rows), 0, 0, &[])
214 .await
215 .map_err(|e| format!("Failed to request remote PTY: {}", e))?;
216
217 if let Some(command) = command {
218 channel
219 .exec(true, command)
220 .await
221 .map_err(|e| format!("Failed to execute remote command: {}", e))?;
222 } else {
223 channel
224 .request_shell(true)
225 .await
226 .map_err(|e| format!("Failed to start remote shell: {}", e))?;
227 }
228
229 if debug_mode {
230 debug!(
231 "SSH channel ready => term: {}, cols: {}, rows: {}, command: {}",
232 term,
233 cols,
234 rows,
235 command.unwrap_or("<login-shell>")
236 );
237 }
238
239 let _raw_mode = RawModeGuard::enable()?;
240 stream_interactive_channel(channel).await
241}
242
243async fn stream_interactive_channel(
244 mut channel: russh::Channel<russh::client::Msg>,
245) -> Result<(), String> {
246 let mut stdin = stdin();
247 let mut stdout = stdout();
248 let mut stderr = stderr();
249 let mut read_buf = [0_u8; 8192];
250 let mut exit_status: Option<u32> = None;
251 let mut stdin_closed = false;
252 let ctrl_c = signal::ctrl_c();
253 tokio::pin!(ctrl_c);
254
255 loop {
256 tokio::select! {
257 read_result = stdin.read(&mut read_buf), if !stdin_closed => {
258 match read_result {
259 Ok(0) => {
260 stdin_closed = true;
261 channel
262 .eof()
263 .await
264 .map_err(|e| format!("Failed to close remote stdin: {}", e))?;
265 }
266 Ok(read_len) => {
267 channel
268 .data(&read_buf[..read_len])
269 .await
270 .map_err(|e| format!("Failed to send SSH input: {}", e))?;
271 }
272 Err(err) => return Err(format!("Failed to read terminal input: {}", err)),
273 }
274 }
275 msg = channel.wait() => match msg {
276 Some(ChannelMsg::Data { ref data }) => {
277 stdout
278 .write_all(data)
279 .await
280 .map_err(|e| format!("Failed to write remote stdout: {}", e))?;
281 stdout
282 .flush()
283 .await
284 .map_err(|e| format!("Failed to flush stdout: {}", e))?;
285 }
286 Some(ChannelMsg::ExtendedData { ref data, ext }) => {
287 if ext == 1 {
288 stderr
289 .write_all(data)
290 .await
291 .map_err(|e| format!("Failed to write remote stderr: {}", e))?;
292 stderr
293 .flush()
294 .await
295 .map_err(|e| format!("Failed to flush stderr: {}", e))?;
296 }
297 }
298 Some(ChannelMsg::ExitStatus { exit_status: status }) => {
299 exit_status = Some(status);
300 }
301 Some(ChannelMsg::ExitSignal { signal_name, .. }) => {
302 if exit_status.is_none() {
303 exit_status = Some(signal_to_exit_status(&signal_name));
304 }
305 }
306 Some(ChannelMsg::Eof) | Some(ChannelMsg::Close) | None => {
307 break;
308 }
309 Some(_) => {}
310 },
311 _ = &mut ctrl_c => {
312 channel
313 .signal(Sig::INT)
314 .await
315 .map_err(|e| format!("Failed to send interrupt signal: {}", e))?;
316 }
317 }
318 }
319
320 match exit_status {
321 Some(0) | None => Ok(()),
322 Some(status) => Err(format!("Remote shell exited with status: {}", status)),
323 }
324}
325
326fn signal_to_exit_status(signal: &Sig) -> u32 {
327 match signal {
328 Sig::INT => 130,
329 Sig::TERM => 143,
330 Sig::QUIT => 131,
331 Sig::KILL => 137,
332 Sig::HUP => 129,
333 Sig::PIPE => 141,
334 _ => 128,
335 }
336}
337
338struct RawModeGuard;
339
340impl RawModeGuard {
341 fn enable() -> Result<Self, String> {
342 enable_raw_mode().map_err(|e| format!("Failed to enable raw terminal mode: {}", e))?;
343 Ok(Self)
344 }
345}
346
347impl Drop for RawModeGuard {
348 fn drop(&mut self) {
349 let _ = disable_raw_mode();
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::{
        normalize_optional_string, resolve_server_check, resolve_term_value,
        InteractiveShellOptions, DEFAULT_TERM,
    };
    use async_ssh2_tokio::client::ServerCheckMethod;
    use std::env;
    use std::path::PathBuf;

    /// Direct-connection options with password auth and no host-key settings.
    fn base_options() -> InteractiveShellOptions {
        InteractiveShellOptions {
            ssh_host: Some("prod.example.com".to_string()),
            ssh_port: 22,
            ssh_username: Some("deploy".to_string()),
            ssh_password: Some("secret".to_string()),
            private_key: None,
            private_key_passphrase: None,
            command: None,
            term: None,
            no_host_key_check: false,
            host_key: None,
            known_hosts_file: None,
            cloudflared_hostname: None,
            cloudflared_binary: None,
            cloudflared_destination: None,
        }
    }

    /// The TERM fallback `resolve_term_value(None)` should produce in the current
    /// environment, computed independently of the code under test.
    fn expected_fallback_term() -> String {
        env::var("TERM")
            .ok()
            .map(|value| value.trim().to_string())
            .filter(|value| !value.is_empty())
            .unwrap_or_else(|| DEFAULT_TERM.to_string())
    }

    #[test]
    fn explicit_host_key_wins() {
        let mut options = base_options();
        options.host_key = Some("AAAAB3NzaC1yc2EAAAADAQABAAABAQDc".to_string());
        options.no_host_key_check = true;

        let server_check = resolve_server_check(&options, true);
        assert!(matches!(server_check, ServerCheckMethod::PublicKey(_)));
    }

    #[test]
    fn blank_host_key_is_ignored() {
        let mut options = base_options();
        options.host_key = Some("   ".to_string());

        let server_check = resolve_server_check(&options, false);
        assert!(matches!(server_check, ServerCheckMethod::DefaultKnownHostsFile));
    }

    #[test]
    fn no_host_key_check_disables_verification_without_pin() {
        let mut options = base_options();
        options.no_host_key_check = true;
        options.known_hosts_file = Some(PathBuf::from("/tmp/known_hosts"));

        let server_check = resolve_server_check(&options, false);
        assert!(matches!(server_check, ServerCheckMethod::NoCheck));
    }

    #[test]
    fn cloudflared_defaults_to_no_check_without_pin() {
        let options = base_options();
        let server_check = resolve_server_check(&options, true);
        assert!(matches!(server_check, ServerCheckMethod::NoCheck));
    }

    #[test]
    fn known_hosts_file_is_used_for_direct_ssh() {
        let mut options = base_options();
        options.known_hosts_file = Some(PathBuf::from("C:/Users/floris/.ssh/known_hosts"));

        let server_check = resolve_server_check(&options, false);
        assert!(matches!(server_check, ServerCheckMethod::KnownHostsFile(_)));
    }

    #[test]
    fn known_hosts_file_overrides_cloudflared_default() {
        let mut options = base_options();
        options.known_hosts_file = Some(PathBuf::from("/tmp/known_hosts"));

        let server_check = resolve_server_check(&options, true);
        assert!(matches!(server_check, ServerCheckMethod::KnownHostsFile(_)));
    }

    #[test]
    fn direct_ssh_defaults_to_default_known_hosts_file() {
        let options = base_options();
        let server_check = resolve_server_check(&options, false);
        assert!(matches!(server_check, ServerCheckMethod::DefaultKnownHostsFile));
    }

    #[test]
    fn resolve_term_prefers_explicit_value() {
        let term = resolve_term_value(Some("screen-256color".to_string()));
        assert_eq!(term, "screen-256color");
    }

    #[test]
    fn resolve_term_trims_explicit_value() {
        let term = resolve_term_value(Some("  vt100 ".to_string()));
        assert_eq!(term, "vt100");
    }

    #[test]
    fn resolve_term_falls_back_to_environment_or_default() {
        // Reads TERM instead of mutating it: `set_var`/`remove_var` race with the other
        // threads of the multi-threaded test harness, which is unsound on Unix.
        let expected = expected_fallback_term();
        assert_eq!(resolve_term_value(None), expected);
        assert_eq!(resolve_term_value(Some("   ".to_string())), expected);
    }

    #[test]
    fn normalize_optional_string_trims_and_filters_empty_values() {
        assert_eq!(
            normalize_optional_string("  cloudflared.example.com  ".to_string()),
            Some("cloudflared.example.com".to_string())
        );
        assert_eq!(normalize_optional_string("   ".to_string()), None);
    }
}