1#![doc(html_root_url = "https://docs.rs/spurs/0.9.2")]
14
15use std::{
16 io::Read,
17 net::{SocketAddr, TcpStream, ToSocketAddrs},
18 path::{Path, PathBuf},
19 sync::{Arc, Mutex},
20 thread::JoinHandle,
21 time::Duration,
22};
23
24use log::{debug, info, trace};
25
26use ssh2::Session;
27
/// Timeout applied to reads and writes on the TCP streams that back SSH sessions.
///
/// Half of this value is used as the connect timeout and retry delay while reconnecting.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
30
/// A command to run on a remote host, together with the options that control how it runs.
///
/// Values are built with [`SshCommand::new`] (or the `cmd!` macro) and refined with the
/// chainable option methods. The type is `Clone` so one command can be run on several shells.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SshCommand {
    /// The command text as written by the caller.
    cmd: String,
    /// Directory to `cd` into before running the command, if any.
    cwd: Option<PathBuf>,
    /// Whether the command is wrapped in `bash -c` (with escaping) before being sent.
    use_bash: bool,
    /// Whether a non-zero exit status is tolerated instead of being reported as an error.
    allow_error: bool,
    /// Whether the command is only printed, not executed.
    dry_run: bool,
    /// Whether to skip requesting a pseudo-terminal for the command.
    no_pty: bool,
}
40
/// The text captured from a command that was run remotely.
///
/// `Default` yields empty output, which is also what a dry run produces.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SshOutput {
    /// Everything the command wrote to its standard output.
    pub stdout: String,
    /// Everything the command wrote to its standard error.
    pub stderr: String,
}
46
47#[derive(Debug)]
49pub enum SshError {
50 KeyNotFound { file: String },
52
53 AuthFailed { key: std::path::PathBuf },
55
56 NonZeroExit { cmd: String, exit: i32 },
58
59 SshError { error: ssh2::Error },
61
62 IoError { error: std::io::Error },
64}
65
66pub struct SshShell {
68 tcp: TcpStream,
70 username: String,
71 key: PathBuf,
72 remote_name: String, remote: SocketAddr,
74 sess: Arc<Mutex<Session>>,
75 dry_run_mode: bool,
76}
77
78pub struct SshSpawnHandle {
80 thread_handle: JoinHandle<(SshShell, Result<SshOutput, SshError>)>,
81}
82
83pub trait Execute: Sized {
85 fn run(&self, cmd: SshCommand) -> Result<SshOutput, SshError>;
89
90 fn duplicate(&self) -> Result<Self, SshError>;
94
95 fn reconnect(&mut self) -> Result<(), SshError>;
97}
98
99impl std::fmt::Display for SshError {
100 fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
101 match self {
102 SshError::KeyNotFound { file } => write!(f, "no such key: {}", file),
103 SshError::AuthFailed { key } => {
104 write!(f, "authentication failed with private key: {:?}", key)
105 }
106 SshError::NonZeroExit { cmd, exit } => {
107 write!(f, "non-zero exit ({}) for command: {}", exit, cmd)
108 }
109 SshError::SshError { error } => write!(f, "{}", error),
110 SshError::IoError { error } => write!(f, "{}", error),
111 }
112 }
113}
114
115impl std::error::Error for SshError {}
116
117impl std::convert::From<ssh2::Error> for SshError {
118 fn from(error: ssh2::Error) -> Self {
119 SshError::SshError { error }
120 }
121}
122
123impl std::convert::From<std::io::Error> for SshError {
124 fn from(error: std::io::Error) -> Self {
125 SshError::IoError { error }
126 }
127}
128
129impl SshCommand {
130 pub fn new(cmd: &str) -> Self {
132 SshCommand {
133 cmd: cmd.to_owned(),
134 cwd: None,
135 use_bash: false,
136 allow_error: false,
137 dry_run: false,
138 no_pty: false,
139 }
140 }
141
142 pub fn cwd<P: AsRef<Path>>(self, cwd: P) -> Self {
144 SshCommand {
145 cwd: Some(cwd.as_ref().to_owned()),
146 ..self
147 }
148 }
149
150 pub fn use_bash(self) -> Self {
152 SshCommand {
153 use_bash: true,
154 ..self
155 }
156 }
157
158 pub fn allow_error(self) -> Self {
160 SshCommand {
161 allow_error: true,
162 ..self
163 }
164 }
165
166 pub fn dry_run(self, is_dry: bool) -> Self {
169 SshCommand {
170 dry_run: is_dry,
171 ..self
172 }
173 }
174
175 pub fn no_pty(self) -> Self {
180 SshCommand {
181 no_pty: true,
182 ..self
183 }
184 }
185
186 #[cfg(any(test, feature = "test"))]
188 pub fn make_cmd(
189 cmd: &str,
190 cwd: Option<PathBuf>,
191 use_bash: bool,
192 allow_error: bool,
193 dry_run: bool,
194 no_pty: bool,
195 ) -> Self {
196 SshCommand {
197 cmd: cmd.into(),
198 cwd,
199 use_bash,
200 allow_error,
201 dry_run,
202 no_pty,
203 }
204 }
205
206 #[cfg(any(test, feature = "test"))]
208 pub fn cmd(&self) -> &str {
209 &self.cmd
210 }
211}
212
213impl SshShell {
214 pub fn with_default_key<A: ToSocketAddrs + std::fmt::Debug>(
221 username: &str,
222 remote: A,
223 ) -> Result<Self, SshError> {
224 const DEFAULT_KEY_SUFFIX: &str = ".ssh/id_rsa";
225 let home = if let Some(home) = dirs::home_dir() {
226 home
227 } else {
228 return Err(SshError::KeyNotFound {
229 file: DEFAULT_KEY_SUFFIX.into(),
230 }
231 .into());
232 };
233
234 SshShell::with_key(username, remote, home.join(DEFAULT_KEY_SUFFIX))
235 }
236
237 pub fn with_any_key<A: Copy + ToSocketAddrs + std::fmt::Debug>(
244 username: &str,
245 remote: A,
246 ) -> Result<Self, SshError> {
247 const DEFAULT_KEY_DIR: &str = ".ssh/";
248 let home = if let Some(home) = dirs::home_dir() {
249 home
250 } else {
251 return Err(SshError::KeyNotFound {
252 file: DEFAULT_KEY_DIR.into(),
253 });
254 };
255 let key_dir = home.join(DEFAULT_KEY_DIR);
256
257 for entry in std::fs::read_dir(&key_dir)? {
258 let entry = entry?;
259 let name = entry.file_name().into_string().unwrap();
260
261 if !name.ends_with(".pub") {
263 continue;
264 }
265
266 let (priv_key, _) = name.split_at(name.len() - 4);
267 let shell = SshShell::with_key(username, remote, key_dir.join(priv_key));
268
269 if shell.is_ok() {
270 return shell;
271 }
272 }
273
274 Err(SshError::KeyNotFound {
275 file: DEFAULT_KEY_DIR.into(),
276 })
277 }
278
279 pub fn with_key<A: ToSocketAddrs + std::fmt::Debug, P: AsRef<Path>>(
286 username: &str,
287 remote: A,
288 key: P,
289 ) -> Result<Self, SshError> {
290 info!("New SSH shell: {}@{:?}", username, remote);
291 debug!("Using key: {:?}", key.as_ref());
292
293 debug!("Create new TCP stream...");
294
295 let tcp = TcpStream::connect(&remote)?;
297 tcp.set_read_timeout(Some(DEFAULT_TIMEOUT))?;
298 tcp.set_write_timeout(Some(DEFAULT_TIMEOUT))?;
299 let remote_name = format!("{:?}", remote);
300 let remote = remote.to_socket_addrs().unwrap().next().unwrap();
301
302 debug!("Create new SSH session...");
303
304 let mut sess = Session::new().unwrap();
306 sess.handshake(&tcp)?;
307 trace!("SSH session handshook.");
308 sess.userauth_pubkey_file(username, None, key.as_ref(), None)?;
309 if !sess.authenticated() {
310 return Err(SshError::AuthFailed {
311 key: key.as_ref().to_path_buf(),
312 }
313 .into());
314 }
315 trace!("SSH session authenticated.");
316
317 println!(
318 "{}",
319 console::style(format!("{}@{} ({})", username, remote_name, remote))
320 .green()
321 .bold()
322 );
323
324 Ok(SshShell {
325 tcp,
326 username: username.to_owned(),
327 key: key.as_ref().to_owned(),
328 remote_name,
329 remote,
330 sess: Arc::new(Mutex::new(sess)),
331 dry_run_mode: false,
332 })
333 }
334
335 pub fn from_existing(shell: &SshShell) -> Result<Self, SshError> {
341 info!("New SSH shell: {}@{:?}", shell.username, shell.remote);
342 debug!("Using key: {:?}", shell.key);
343
344 debug!("Create new TCP stream...");
345
346 let tcp = TcpStream::connect(&shell.remote)?;
348 tcp.set_read_timeout(Some(DEFAULT_TIMEOUT))?;
349 tcp.set_write_timeout(Some(DEFAULT_TIMEOUT))?;
350 let remote = shell.remote.clone();
351
352 debug!("Create new SSH session...");
353
354 let mut sess = Session::new().unwrap();
356 sess.handshake(&tcp)?;
357 trace!("SSH session handshook.");
358 sess.userauth_pubkey_file(&shell.username, None, shell.key.as_ref(), None)?;
359 if !sess.authenticated() {
360 return Err(SshError::AuthFailed {
361 key: shell.key.clone(),
362 }
363 .into());
364 }
365 trace!("SSH session authenticated.");
366
367 println!(
368 "{}",
369 console::style(format!(
370 "{}@{} ({})",
371 shell.username, shell.remote_name, remote
372 ))
373 .green()
374 .bold()
375 );
376
377 Ok(SshShell {
378 tcp,
379 username: shell.username.clone(),
380 key: shell.key.clone(),
381 remote_name: shell.remote_name.clone(),
382 remote,
383 sess: Arc::new(Mutex::new(sess)),
384 dry_run_mode: false,
385 })
386 }
387
388 pub fn set_dry_run(&mut self, on: bool) {
392 self.dry_run_mode = on;
393 info!(
394 "Toggled dry run mode: {}",
395 if self.dry_run_mode { "on" } else { "off" }
396 );
397 }
398
399 pub fn spawn(&self, cmd: SshCommand) -> Result<SshSpawnHandle, SshError> {
400 debug!("spawn({:?})", cmd);
401 let shell = Self::from_existing(self)?;
402 let cmd = if self.dry_run_mode {
403 cmd.dry_run(true)
404 } else {
405 cmd
406 };
407
408 let thread_handle = std::thread::spawn(move || {
409 let result = shell.run(cmd);
410 (shell, result)
411 });
412
413 debug!("spawned thread for command.");
414
415 Ok(SshSpawnHandle { thread_handle })
416 }
417
418 fn run_with_chan_and_opts(
419 host_and_username: String, mut chan: ssh2::Channel,
421 cmd_opts: SshCommand,
422 ) -> Result<SshOutput, SshError> {
423 debug!("run_with_chan_and_opts({:?})", cmd_opts);
424
425 let SshCommand {
426 cwd,
427 cmd,
428 use_bash,
429 allow_error,
430 dry_run,
431 no_pty,
432 } = cmd_opts;
433
434 let msg = cmd.clone();
437
438 let cmd = if use_bash {
440 format!("bash -c {}", escape_for_bash(&cmd))
441 } else {
442 cmd
443 };
444
445 debug!("After shell escaping: {:?}", cmd);
446
447 let cmd = if let Some(cwd) = &cwd {
448 format!("cd {} ; {}", cwd.display(), cmd)
449 } else {
450 cmd
451 };
452
453 debug!("After cwd: {:?}", cmd);
454
455 if let Some(cwd) = cwd {
457 println!(
458 "{:-<80}\n{}\n{}\n{}",
459 "",
460 console::style(host_and_username).blue(),
461 console::style(cwd.display()).blue(),
462 console::style(msg).yellow().bold()
463 );
464 } else {
465 println!(
466 "{:-<80}\n{}\n{}",
467 "",
468 console::style(host_and_username).blue(),
469 console::style(msg).yellow().bold()
470 );
471 }
472
473 let mut stdout = String::new();
474 let mut stderr = String::new();
475
476 if dry_run {
478 chan.close()?;
479 chan.wait_close()?;
480
481 debug!("Closed channel after dry run.");
482
483 return Ok(SshOutput { stdout, stderr });
484 }
485
486 if !no_pty {
488 chan.request_pty("vt100", None, None)?;
489 debug!("Requested pty.");
490 }
491
492 debug!("Execute command remotely (asynchronous)...");
494 chan.exec(&cmd)?;
495
496 trace!("Read stdout...");
497
498 let mut buf = [0; 256];
500 while chan.read(&mut buf)? > 0 {
501 let out = String::from_utf8_lossy(&buf);
502 let out = out.trim_end_matches('\u{0}');
503 print!("{}", out);
504 stdout.push_str(out);
505
506 buf.iter_mut().for_each(|x| *x = 0);
508 }
509
510 trace!("No more stdout.");
511
512 chan.close()?;
514 chan.wait_close()?;
515
516 debug!("Command completed remotely.");
517
518 buf.iter_mut().for_each(|x| *x = 0);
520
521 trace!("Read stderr...");
522
523 while chan.stderr().read(&mut buf)? > 0 {
525 let err = String::from_utf8_lossy(&buf);
526 let err = err.trim_end_matches('\u{0}');
527 print!("{}", err);
528 stderr.push_str(err);
529
530 buf.iter_mut().for_each(|x| *x = 0);
532 }
533
534 trace!("No more stderr.");
535 debug!("Checking exit status.");
536
537 let exit = chan.exit_status()?;
539 debug!("Exit status: {}", exit);
540 if exit != 0 && !allow_error {
541 return Err(SshError::NonZeroExit { cmd, exit }.into());
542 }
543
544 trace!("Done with command.");
545
546 Ok(SshOutput { stdout, stderr })
548 }
549}
550
551impl Execute for SshShell {
552 fn run(&self, cmd: SshCommand) -> Result<SshOutput, SshError> {
553 debug!("run(cmd)");
554 let sess = self.sess.lock().unwrap();
555 debug!("Attempt to crate channel...");
556 let chan = sess.channel_session()?;
557 debug!("Channel created.");
558 let host_and_username = format!("{}@{}", self.username, self.remote_name);
559 let cmd = if self.dry_run_mode {
560 cmd.dry_run(true)
561 } else {
562 cmd
563 };
564 Self::run_with_chan_and_opts(host_and_username, chan, cmd)
565 }
566
567 fn duplicate(&self) -> Result<Self, SshError> {
568 Self::from_existing(self)
569 }
570
571 fn reconnect(&mut self) -> Result<(), SshError> {
572 info!("Reconnect attempt.");
573
574 trace!("Attempt to create new TCP stream...");
575 loop {
576 print!("{}", console::style("Attempt Reconnect ... ").red());
577 match TcpStream::connect_timeout(&self.remote, DEFAULT_TIMEOUT / 2) {
578 Ok(tcp) => {
579 self.tcp = tcp;
580 break;
581 }
582 Err(e) => {
583 trace!("{:?}", e);
584 println!("{}", console::style("failed, retrying").red());
585 std::thread::sleep(DEFAULT_TIMEOUT / 2);
586 }
587 }
588 }
589
590 println!(
591 "{}",
592 console::style("TCP connected, doing SSH handshake").red()
593 );
594
595 debug!("Attempt to create new SSH session...");
597 let mut sess = Session::new().unwrap();
598 sess.handshake(&self.tcp)?;
599 trace!("Handshook!");
600 sess.userauth_pubkey_file(&self.username, None, self.key.as_ref(), None)?;
601 if !sess.authenticated() {
602 return Err(SshError::AuthFailed {
603 key: self.key.clone(),
604 }
605 .into());
606 }
607 trace!("authenticated!");
608
609 let self_sess = Arc::get_mut(&mut self.sess).unwrap().get_mut().unwrap();
612 let _old_sess = std::mem::replace(self_sess, sess);
613
614 println!(
615 "{}",
616 console::style(format!("{}@{}", self.username, self.remote))
617 .green()
618 .bold()
619 );
620
621 Ok(())
622 }
623}
624
625impl std::fmt::Debug for SshShell {
626 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
627 write!(
628 f,
629 "SshShell {{ {}@{:?} dry_run={} key={:?} }}",
630 self.username, self.remote, self.dry_run_mode, self.key
631 )
632 }
633}
634
635impl SshSpawnHandle {
636 pub fn join(self) -> (SshShell, Result<SshOutput, SshError>) {
638 debug!("Blocking on spawned commmand.");
639 let ret = self.thread_handle.join().unwrap();
640 debug!("Spawned commmand complete.");
641 ret
642 }
643}
644
645impl std::fmt::Debug for SshSpawnHandle {
646 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
647 write!(f, "SshSpawnHandle {{ running }}")
648 }
649}
650
/// Builds an `SshCommand` from a format string and optional arguments, exactly as `format!`
/// would build the command text.
#[macro_export]
macro_rules! cmd {
    ($fmt:expr $(, $($arg:tt)*)?) => {
        $crate::SshCommand::new(&format!($fmt $(, $($arg)*)?))
    };
}
671
/// Escapes `s` so that bash reads it back as one literal word.
///
/// ASCII letters and digits are copied through; every other character is preceded by a
/// backslash. Two cases need different treatment:
///
/// * A newline is wrapped in single quotes, because bash treats backslash-newline as a line
///   continuation and would drop the newline.
/// * The empty string becomes `''`, because an empty expansion would leave no word at all.
fn escape_for_bash(s: &str) -> String {
    if s.is_empty() {
        return "''".to_owned();
    }

    // Worst case for typical input is one backslash per character.
    let mut new = String::with_capacity(2 * s.len());

    for c in s.chars() {
        match c {
            c if c.is_ascii_alphanumeric() => new.push(c),
            '\n' => new.push_str("'\n'"),
            _ => {
                new.push('\\');
                new.push(c);
            }
        }
    }

    new
}
691
/// Unit tests for the `cmd!` macro and for bash escaping.
#[cfg(test)]
mod test {
    use crate::{cmd, SshCommand};

    /// `cmd!` formats its arguments the way `format!` does.
    #[test]
    fn test_cmd_macro() {
        let expected = SshCommand::new("ls 3");
        assert_eq!(cmd!("{} {}", "ls", 3), expected);
    }

    mod test_escape_for_bash {
        use crate::escape_for_bash;

        /// Plain alphanumeric text passes through untouched.
        #[test]
        fn simple() {
            const TEST_STRING: &str = "ls";
            let escaped = escape_for_bash(TEST_STRING);
            assert_eq!(escaped, "ls");
        }

        /// A string full of quotes and punctuation survives a round trip through a real
        /// `bash -c echo`.
        #[test]
        fn more_complex() {
            use std::process::Command;

            const TEST_STRING: &str =
                r#""Bob?!", said she, "I though you said 'I can't be there'!""#;

            let script = format!("echo {}", escape_for_bash(TEST_STRING));
            let output = Command::new("bash")
                .arg("-c")
                .arg(&script)
                .output()
                .unwrap();
            let echoed = String::from_utf8(output.stdout).unwrap();

            assert_eq!(echoed.trim(), TEST_STRING);
        }
    }
}